diff --git a/pymongo/member.py b/pymongo/member.py new file mode 100644 index 000000000..34862de46 --- /dev/null +++ b/pymongo/member.py @@ -0,0 +1,141 @@ +# Copyright 2013 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Represent a mongod / mongos instance""" + +from pymongo import common +from pymongo.read_preferences import ReadPreference + +# Member states +PRIMARY = 1 +SECONDARY = 2 +ARBITER = 3 +OTHER = 4 + + +# TODO: rename 'Server' or 'ServerDescription'. +class Member(object): + """Immutable representation of one server. + + :Parameters: + - `host`: A (host, port) pair + - `connection_pool`: A Pool instance + - `ismaster_response`: A dict, MongoDB's ismaster response + - `ping_time`: A MovingAverage instance + - `up`: Whether we think this member is available + """ + # For unittesting only. Use under no circumstances! + _host_to_ping_time = {} + + def __init__(self, host, connection_pool, ismaster_response, ping_time, up): + self.host = host + self.pool = connection_pool + self.ismaster_response = ismaster_response + self.ping_time = ping_time + self.up = up + self.is_mongos = (ismaster_response.get('msg') == 'isdbgrid') + + if ismaster_response['ismaster']: + self.state = PRIMARY + elif ismaster_response.get('secondary'): + self.state = SECONDARY + elif ismaster_response.get('arbiterOnly'): + self.state = ARBITER + else: + self.state = OTHER + + self.tags = ismaster_response.get('tags', {}) + self.max_bson_size = ismaster_response.get( + 'maxBsonObjectSize', common.MAX_BSON_SIZE) + self.max_message_size = ismaster_response.get( + 'maxMessageSizeBytes', 2 * self.max_bson_size) + self.min_wire_version = ismaster_response.get( + 'minWireVersion', common.MIN_WIRE_VERSION) + self.max_wire_version = ismaster_response.get( + 'maxWireVersion', common.MAX_WIRE_VERSION) + + self.set_name = ismaster_response.get('setName') + + def clone_with(self, ismaster_response, ping_time_sample): + """Get a clone updated with ismaster response and a single ping time. + """ + ping_time = self.ping_time.clone_with(ping_time_sample) + return Member(self.host, self.pool, ismaster_response, ping_time, True) + + def clone_down(self): + """Get a clone of this Member, but with up=False. + """ + return Member( + self.host, self.pool, self.ismaster_response, self.ping_time, + False) + + @property + def is_primary(self): + return self.state == PRIMARY + + @property + def is_secondary(self): + return self.state == SECONDARY + + @property + def is_arbiter(self): + return self.state == ARBITER + + def get_avg_ping_time(self): + """Get a moving average of this member's ping times. + """ + if self.host in Member._host_to_ping_time: + # Simulate ping times for unittesting + return Member._host_to_ping_time[self.host] + + return self.ping_time.get() + + def matches_mode(self, mode): + assert not self.is_mongos, \ + "Tried to match read preference mode on a mongos Member" + + if mode == ReadPreference.PRIMARY and not self.is_primary: + return False + + if mode == ReadPreference.SECONDARY and not self.is_secondary: + return False + + # If we're not primary or secondary, then we're in a state like + # RECOVERING and we don't match any mode + return self.is_primary or self.is_secondary + + def matches_tags(self, tags): + """Return True if this member's tags are a superset of the passed-in + tags. E.g., if this member is tagged {'dc': 'ny', 'rack': '1'}, + then it matches {'dc': 'ny'}. + """ + for key, value in tags.items(): + if key not in self.tags or self.tags[key] != value: + return False + + return True + + def matches_tag_sets(self, tag_sets): + """Return True if this member matches any of the tag sets, e.g. + [{'dc': 'ny'}, {'dc': 'la'}, {}] + """ + for tags in tag_sets: + if self.matches_tags(tags): + return True + + return False + + def __str__(self): + return '' % ( + self.host[0], self.host[1], self.is_primary, self.up) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 6ffb73c34..4ad49c762 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -38,6 +38,7 @@ import datetime import random import socket import struct +import threading import time import warnings @@ -48,6 +49,7 @@ from pymongo import (auth, helpers, message, pool, + thread_util, uri_parser) from pymongo.common import HAS_SSL from pymongo.cursor_manager import CursorManager @@ -58,7 +60,7 @@ from pymongo.errors import (AutoReconnect, InvalidDocument, InvalidURI, OperationFailure) - +from pymongo.member import Member EMPTY = b("") @@ -83,12 +85,6 @@ class MongoClient(common.BaseObject): HOST = "localhost" PORT = 27017 - __max_bson_size = common.MAX_BSON_SIZE - __max_message_size = common.MAX_MESSAGE_SIZE - __min_wire_version = common.MIN_WIRE_VERSION - # TODO: write commands with _connect=False - __max_wire_version = common.MAX_WIRE_VERSION - def __init__(self, host=None, port=None, max_pool_size=100, document_class=dict, tz_aware=False, _connect=True, **kwargs): @@ -260,15 +256,19 @@ class MongoClient(common.BaseObject): if not seeds: raise ConfigurationError("need to specify at least one host") - self.__nodes = seeds - self.__host = None - self.__port = None - self.__is_primary = False - self.__is_mongos = False + # Seeds are only used before first connection attempt; nodes are then + # used for any reconnects. Nodes are set to all replica set members + # if connecting to a replica set (besides arbiters), or to all + # available mongoses from the seed list, or to the one standalone + # mongod. + self.__seeds = frozenset(seeds) + self.__nodes = frozenset() + self.__member = None # TODO: Rename to __server. - # _pool_class option is for deep customization of PyMongo, e.g. Motor. - # SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO 10GEN. + # _pool_class and _event_class are for deep customization of PyMongo, + # e.g. Motor. SHOULD NOT BE USED BY THIRD-PARTY DEVELOPERS. pool_class = kwargs.pop('_pool_class', pool.Pool) + event_class = kwargs.pop('_event_class', None) options = {} for option, value in kwargs.iteritems(): @@ -282,11 +282,7 @@ class MongoClient(common.BaseObject): self.__cursor_manager = CursorManager(self) self.__repl = options.get('replicaset') - if len(seeds) == 1 and not self.__repl: - self.__direct = True - else: - self.__direct = False - self.__nodes = set() + self.__direct = len(seeds) == 1 and not self.__repl self.__net_timeout = options.get('sockettimeoutms') self.__conn_timeout = options.get('connecttimeoutms') @@ -323,20 +319,23 @@ class MongoClient(common.BaseObject): "from PyPI.") self.__use_greenlets = options.get('use_greenlets', False) - self.__pool = pool_class( - None, - self.__max_pool_size, - self.__net_timeout, - self.__conn_timeout, - self.__use_ssl, - use_greenlets=self.__use_greenlets, - ssl_keyfile=self.__ssl_keyfile, - ssl_certfile=self.__ssl_certfile, - ssl_cert_reqs=self.__ssl_cert_reqs, - ssl_ca_certs=self.__ssl_ca_certs, - wait_queue_timeout=self.__wait_queue_timeout, - wait_queue_multiple=self.__wait_queue_multiple) + self.__pool_class = pool_class + self.__connecting = False + if self.__use_greenlets: + # Greenlets don't need to lock around access to the Member; + # they're only interrupted when they do I/O. + self.__connecting_lock = thread_util.DummyLock() + else: + self.__connecting_lock = threading.Lock() + + if event_class: + self.__event_class = event_class + else: + event_class = lambda: thread_util.create_event(self.__use_greenlets) + self.__event_class = event_class + + self.__future_member = None self.__document_class = document_class self.__tz_aware = common.validate_boolean('tz_aware', tz_aware) self.__auto_start_request = options.get('auto_start_request', False) @@ -353,7 +352,7 @@ class MongoClient(common.BaseObject): if _connect: try: - self.__find_node(seeds) + self._ensure_connected(True) except AutoReconnect, e: # ConnectionFailure makes more sense here than AutoReconnect raise ConnectionFailure(str(e)) @@ -441,14 +440,15 @@ class MongoClient(common.BaseObject): 'to this database. You must logout first.') if connect: - sock_info = self.__socket() + member = self.__ensure_member() + sock_info = self.__socket(member) try: # Since __check_auth was called in __socket # there is no need to call it here. auth.authenticate(credentials, sock_info, self.__simple_command) sock_info.authset.add(credentials) finally: - self.__pool.maybe_return_socket(sock_info) + member.pool.maybe_return_socket(sock_info) self.__auth_credentials[source] = credentials @@ -458,6 +458,21 @@ class MongoClient(common.BaseObject): if source in self.__auth_credentials: del self.__auth_credentials[source] + def __create_pool(self, pair): + return self.__pool_class( + pair, + self.__max_pool_size, + self.__net_timeout, + self.__conn_timeout, + self.__use_ssl, + use_greenlets=self.__use_greenlets, + ssl_keyfile=self.__ssl_keyfile, + ssl_certfile=self.__ssl_certfile, + ssl_cert_reqs=self.__ssl_cert_reqs, + ssl_ca_certs=self.__ssl_ca_certs, + wait_queue_timeout=self.__wait_queue_timeout, + wait_queue_multiple=self.__wait_queue_multiple) + def __check_auth(self, sock_info): """Authenticate using cached database credentials. """ @@ -476,6 +491,13 @@ class MongoClient(common.BaseObject): sock_info, self.__simple_command) sock_info.authset.add(credentials) + def __member_property(self, attr_name, default=None): + member = self.__member + if member: + return getattr(member, attr_name) + + return default + @property def host(self): """Current connected host. @@ -483,7 +505,11 @@ class MongoClient(common.BaseObject): .. versionchanged:: 1.3 ``host`` is now a property rather than a method. """ - return self.__host + member = self.__member + if member: + return member.host[0] + + return None @property def port(self): @@ -492,8 +518,11 @@ class MongoClient(common.BaseObject): .. versionchanged:: 1.3 ``port`` is now a property rather than a method. """ - return self.__port + member = self.__member + if member: + return member.host[1] + return None @property def is_primary(self): """If this instance is connected to a standalone, a replica set @@ -501,7 +530,7 @@ class MongoClient(common.BaseObject): .. versionadded:: 2.3 """ - return self.__is_primary + return self.__member_property('is_primary', False) @property def is_mongos(self): @@ -509,7 +538,7 @@ class MongoClient(common.BaseObject): .. versionadded:: 2.3 """ - return self.__is_mongos + return self.__member_property('is_mongos', False) @property def max_pool_size(self): @@ -543,9 +572,8 @@ class MongoClient(common.BaseObject): def nodes(self): """List of all known nodes. - Includes both nodes specified when this instance was created, - as well as nodes discovered through the replica set discovery - mechanism. + Nodes are either specified when this instance was created, + or discovered through the replica set discovery mechanism. .. versionadded:: 1.8 """ @@ -586,7 +614,7 @@ class MongoClient(common.BaseObject): .. versionadded:: 1.10 """ - return self.__max_bson_size + return self.__member_property('max_bson_size', common.MAX_BSON_SIZE) @property def max_message_size(self): @@ -596,7 +624,8 @@ class MongoClient(common.BaseObject): .. versionadded:: 2.6 """ - return self.__max_message_size + return self.__member_property( + 'max_message_size', common.MAX_MESSAGE_SIZE) @property def min_wire_version(self): @@ -606,7 +635,8 @@ class MongoClient(common.BaseObject): .. versionadded:: 2.7 """ - return self.__min_wire_version + return self.__member_property( + 'min_wire_version', common.MIN_WIRE_VERSION) @property def max_wire_version(self): @@ -616,7 +646,8 @@ class MongoClient(common.BaseObject): .. versionadded:: 2.7 """ - return self.__max_wire_version + return self.__member_property( + 'max_wire_version', common.MAX_WIRE_VERSION) def __simple_command(self, sock_info, dbname, spec): """Send a command to the server. @@ -638,54 +669,48 @@ class MongoClient(common.BaseObject): def __try_node(self, node): """Try to connect to this node and see if it works for our connection - type. Returns ((host, port), ismaster, isdbgrid, res_time). + type. Returns a Member and set of hosts (including this one). Doesn't + modify state. :Parameters: - `node`: The (host, port) pair to try. """ - self.disconnect() - self.__host, self.__port = node - # Call 'ismaster' directly so we can get a response time. - sock_info = self.__socket() + connection_pool = self.__create_pool(node) + sock_info = connection_pool.get_socket() try: response, res_time = self.__simple_command(sock_info, 'admin', {'ismaster': 1}) finally: - self.__pool.maybe_return_socket(sock_info) + connection_pool.maybe_return_socket(sock_info) - # Are we talking to a mongos? - isdbgrid = response.get('msg', '') == 'isdbgrid' + member = Member( + node, + connection_pool, + response, + res_time, + True) - if "maxBsonObjectSize" in response: - self.__max_bson_size = response["maxBsonObjectSize"] - if "maxMessageSizeBytes" in response: - self.__max_message_size = response["maxMessageSizeBytes"] - if "minWireVersion" in response: - self.__min_wire_version = response["minWireVersion"] - if "maxWireVersion" in response: - self.__max_wire_version = response["maxWireVersion"] + nodes = frozenset([node]) # Replica Set? if not self.__direct: # Check that this host is part of the given replica set. - if self.__repl: - set_name = response.get('setName') - if set_name != self.__repl: - raise ConfigurationError("%s:%d is not a member of " - "replica set %s" - % (node[0], node[1], self.__repl)) + if self.__repl and member.set_name != self.__repl: + raise ConfigurationError("%s:%d is not a member of " + "replica set %s" + % (node[0], node[1], self.__repl)) + if "hosts" in response: - self.__nodes = set([_partition_node(h) - for h in response["hosts"]]) - else: - # The user passed a seed list of standalone or - # mongos instances. - self.__nodes.add(node) - if response["ismaster"]: - return node, True, isdbgrid, res_time + nodes = frozenset([ + _partition_node(h) for h in response["hosts"]]) + + if member.is_primary: + return member, nodes + elif "primary" in response: + # Shortcut: a secondary usually tells us who the primary is. candidate = _partition_node(response["primary"]) return self.__try_node(candidate) @@ -693,29 +718,94 @@ class MongoClient(common.BaseObject): raise AutoReconnect('%s:%d is not primary or master' % node) # Direct connection - if response.get("arbiterOnly", False) and not self.__direct: + if member.is_arbiter and not self.__direct: raise ConfigurationError("%s:%d is an arbiter" % node) - return node, response['ismaster'], isdbgrid, res_time + + return member, nodes def __pick_nearest(self, candidates): - """Return the 'nearest' candidate based on response time. + """Return the 'nearest' Member instance based on response time. + + Doesn't modify state. """ latency = self.secondary_acceptable_latency_ms - # Only used for mongos high availability, res_time is in seconds. - fastest = min([res_time for candidate, res_time in candidates]) + # Only used for mongos high availability, ping_time is in seconds. + fastest = min([ + member.ping_time for member in candidates]) + near_candidates = [ - candidate for candidate, res_time in candidates - if res_time - fastest < latency / 1000.0 - ] + member for member in candidates + if member.ping_time - fastest < latency / 1000.0] - node = random.choice(near_candidates) - # Clear the pool from the last choice. - self.disconnect() - self.__host, self.__port = node - return node + return random.choice(near_candidates) - def __find_node(self, seeds=None): - """Find a host, port pair suitable for our connection type. + def __ensure_member(self): + """Connect and return a Member instance, or raise AutoReconnect.""" + # If `connecting` is False, no thread is in __find_node(), + # and `future_member` is resolved. `member` may be None if the + # last __find_node() attempt failed, otherwise it is in `nodes`. + # + # If `connecting` is True, a thread is in __find_node(), + # `member` is None, and `future_member` is pending. + # + # To violate these invariants temporarily, acquire the lock. + # Note that disconnect() interacts with this method. + self.__connecting_lock.acquire() + if self.__member: + member = self.__member + self.__connecting_lock.release() + return member + + elif self.__connecting: + # A thread is in __find_node(). Wait. + future = self.__future_member + self.__connecting_lock.release() + return future.result() + + else: + self.__connecting = True + future = self.__future_member = thread_util.Future( + self.__event_class) + + self.__connecting_lock.release() + + member = None + nodes = None + exc = None + + try: + try: + member, nodes = self.__find_node() + return member + except Exception, e: + exc = e + raise + finally: + # We're either returning a Member or raising an error. + # Propagate either outcome to waiting threads. + self.__connecting_lock.acquire() + self.__member = member + self.__connecting = False + + # If we discovered a set of nodes, use them from now on; + # otherwise we're raising an error. Stick with the last + # known good set of nodes. + if nodes: + self.__nodes = nodes + + if member: + # Unblock waiting threads. + future.set_result(member) + else: + # Raise exception in waiting threads. + future.set_exception(exc) + + self.__connecting_lock.release() + + def __find_node(self): + """Find a server suitable for our connection type. + + Returns a Member and a set of nodes. Doesn't modify state. If only one host was supplied to __init__ see if we can connect to it. Don't check if the host is a master/primary so we can make @@ -724,75 +814,77 @@ class MongoClient(common.BaseObject): If more than one host was supplied treat them as a seed list for connecting to a replica set or to support high availability for - mongos. If connecting to a replica set try to find the primary - and fail if we can't, possibly updating any replSet information - on success. If a mongos seed list was provided find the "nearest" - mongos and return it. + mongos. If connecting to a replica set try to find the primary, + and set `nodes` to list of all members. + + If a mongos seed list was provided find the "nearest" mongos and + return it, setting `nodes` to all mongoses in the seed list that + are up. Otherwise we iterate through the list trying to find a host we can send write operations to. - - Sets __host and __port so that :attr:`host` and :attr:`port` - will return the address of the connected host. Sets __is_primary to - True if this is a primary or master, else False. Sets __is_mongos - to True if the connection is to a mongos. """ + assert not self.__member, \ + "__find_node unexpectedly running with a non-null Member" + errors = [] mongos_candidates = [] - candidates = seeds or self.__nodes.copy() + candidates = self.__nodes or self.__seeds + chosen_member = None + discovered_nodes = None + for candidate in candidates: try: - node, ismaster, isdbgrid, res_time = self.__try_node(candidate) - self.__is_primary = ismaster - self.__is_mongos = isdbgrid - # No need to calculate nearest if we only have one mongos. - if isdbgrid and not self.__direct: - mongos_candidates.append((node, res_time)) + member, nodes = self.__try_node(candidate) + if member.is_mongos and not self.__direct: + mongos_candidates.append(member) + + # We intend to find all the mongoses; keep trying nodes. continue elif len(mongos_candidates): raise ConfigurationError("Seed list cannot contain a mix " "of mongod and mongos instances.") - return node + + # We've found a suitable node. + chosen_member = member + discovered_nodes = nodes + break except OperationFailure: # The server is available but something failed, probably auth. raise except Exception, why: errors.append(str(why)) - # If we have a mongos seed list, pick the "nearest" member. if len(mongos_candidates): - self.__is_mongos = True - return self.__pick_nearest(mongos_candidates) + # If we have a mongos seed list, pick the "nearest" member. + chosen_member = self.__pick_nearest(mongos_candidates) + mongoses = frozenset(m.host for m in mongos_candidates) - # Otherwise, try any hosts we discovered that were not in the seed list. - for candidate in self.__nodes - candidates: - try: - node, ismaster, isdbgrid, _ = self.__try_node(candidate) - self.__is_primary = ismaster - self.__is_mongos = isdbgrid - return node - except Exception, why: - errors.append(str(why)) - # Couldn't find a suitable host. - self.disconnect() - raise AutoReconnect(', '.join(errors)) + # The first time, __nodes is empty and mongoses becomes nodes. + return chosen_member, self.__nodes or mongoses - def __socket(self): - """Get a SocketInfo from the pool. + if not chosen_member: + # Couldn't find a suitable host. + raise AutoReconnect(', '.join(errors)) + + return chosen_member, discovered_nodes + + def __socket(self, member): + """Get a SocketInfo. + + Calls disconnect() on error. """ - host, port = (self.__host, self.__port) - if host is None or (port is None and '/' not in host): - host, port = self.__find_node() - + connection_pool = member.pool try: - if self.auto_start_request and not self.in_request(): - self.start_request() + if self.auto_start_request and not connection_pool.in_request(): + connection_pool.start_request() - sock_info = self.__pool.get_socket((host, port)) + sock_info = connection_pool.get_socket() except socket.error, why: self.disconnect() # Check if a unix domain socket + host, port = member.host if host.endswith('.sock'): host_details = "%s:" % host else: @@ -802,16 +894,14 @@ class MongoClient(common.BaseObject): try: self.__check_auth(sock_info) except OperationFailure: - self.__pool.maybe_return_socket(sock_info) + connection_pool.maybe_return_socket(sock_info) raise return sock_info - def _ensure_connected(self, dummy): + def _ensure_connected(self, sync=False): """Ensure this client instance is connected to a mongod/s. """ - host, port = (self.__host, self.__port) - if host is None or (port is None and '/' not in host): - self.__find_node() + self.__ensure_member() def disconnect(self): """Disconnect from MongoDB. @@ -825,9 +915,13 @@ class MongoClient(common.BaseObject): .. seealso:: :meth:`end_request` .. versionadded:: 1.3 """ - self.__pool.reset() - self.__host = None - self.__port = None + self.__connecting_lock.acquire() + member, self.__member = self.__member, None + self.__connecting_lock.release() + + # Close sockets promptly. + if member: + member.pool.reset() def close(self): """Alias for :meth:`disconnect` @@ -866,16 +960,20 @@ class MongoClient(common.BaseObject): # calls select() if the socket hasn't been checked in the last second, # or it may create a new socket, in which case calling select() is # redundant. - sock_info = None + member, sock_info = None, None try: try: - sock_info = self.__socket() + member = self.__ensure_member() + if not member: + return False + + sock_info = member.pool.get_socket() return not pool._closed(sock_info.sock) except (socket.error, ConnectionFailure): return False finally: if sock_info: - self.__pool.maybe_return_socket(sock_info) + member.pool.maybe_return_socket(sock_info) def set_cursor_manager(self, manager_class): """Set this client's cursor manager. @@ -950,12 +1048,12 @@ class MongoClient(common.BaseObject): """ if len(message) == 3: (request_id, data, max_doc_size) = message - if max_doc_size > self.__max_bson_size: + if max_doc_size > self.max_bson_size: raise InvalidDocument("BSON document too large (%d bytes)" " - the connected server supports" " BSON document sizes up to %d" " bytes." % - (max_doc_size, self.__max_bson_size)) + (max_doc_size, self.max_bson_size)) return (request_id, data) else: # get_more and kill_cursors messages @@ -983,7 +1081,8 @@ class MongoClient(common.BaseObject): # The write won't succeed, bail as if we'd done a getLastError raise AutoReconnect("not master") - sock_info = self.__socket() + member = self.__ensure_member() + sock_info = self.__socket(member) try: try: (request_id, data) = self.__check_bson_size(message) @@ -1008,7 +1107,7 @@ class MongoClient(common.BaseObject): sock_info.close() raise finally: - self.__pool.maybe_return_socket(sock_info) + member.pool.maybe_return_socket(sock_info) def __receive_data_on_socket(self, length, sock_info): """Lowest level receive operation. @@ -1063,7 +1162,8 @@ class MongoClient(common.BaseObject): :Parameters: - `message`: (request_id, data) pair making up the message to send """ - sock_info = self.__socket() + member = self.__ensure_member() + sock_info = self.__socket(member) exhaust = kwargs.get('exhaust') try: try: @@ -1075,13 +1175,13 @@ class MongoClient(common.BaseObject): if "network_timeout" in kwargs: sock_info.sock.settimeout(self.__net_timeout) - return (None, (response, sock_info, self.__pool)) + return (None, (response, sock_info, member.pool)) except (ConnectionFailure, socket.error), e: self.disconnect() raise AutoReconnect(str(e)) finally: if not exhaust: - self.__pool.maybe_return_socket(sock_info) + member.pool.maybe_return_socket(sock_info) def _exhaust_next(self, sock_info): """Used with exhaust cursors to get the next batch off the socket. @@ -1118,14 +1218,16 @@ class MongoClient(common.BaseObject): The :class:`~pymongo.pool.Request` return value. :meth:`start_request` previously returned None """ - self.__pool.start_request() + member = self.__ensure_member() + member.pool.start_request() return pool.Request(self) def in_request(self): """True if this thread is in a request, meaning it has a socket reserved for its exclusive use. """ - return self.__pool.in_request() + member = self.__member # Don't try to connect if disconnected. + return member and member.pool.in_request() def end_request(self): """Undo :meth:`start_request`. If :meth:`end_request` is called as many @@ -1143,13 +1245,13 @@ class MongoClient(common.BaseObject): in the middle of a sequence of operations in which ordering is important. This could lead to unexpected results. """ - self.__pool.end_request() + member = self.__member # Don't try to connect if disconnected. + if member: + member.pool.end_request() def __eq__(self, other): if isinstance(other, self.__class__): - us = (self.__host, self.__port) - them = (other.__host, other.__port) - return us == them + return self.host == other.host and self.port == other.port return NotImplemented def __ne__(self, other): @@ -1157,7 +1259,7 @@ class MongoClient(common.BaseObject): def __repr__(self): if len(self.__nodes) == 1: - return "MongoClient(%r, %r)" % (self.__host, self.__port) + return "MongoClient(%r, %r)" % (self.host, self.port) else: return "MongoClient(%r)" % ["%s:%d" % n for n in self.__nodes] diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index 55104db0f..4faf5e4b2 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -49,6 +49,7 @@ from pymongo import (auth, pool, thread_util, uri_parser) +from pymongo.member import Member from pymongo.read_preferences import ( ReadPreference, select_member, modes, MovingAverage) from pymongo.errors import (AutoReconnect, @@ -63,11 +64,6 @@ from pymongo.thread_util import DummyLock EMPTY = b("") MAX_RETRY = 3 -# Member states -PRIMARY = 1 -SECONDARY = 2 -OTHER = 3 - MONITORS = set() def register_monitor(monitor): @@ -390,110 +386,6 @@ except ImportError: pass -class Member(object): - """Immutable representation of one member of a replica set. - - :Parameters: - - `host`: A (host, port) pair - - `connection_pool`: A Pool instance - - `ismaster_response`: A dict, MongoDB's ismaster response - - `ping_time`: A MovingAverage instance - - `up`: Whether we think this member is available - """ - # For unittesting only. Use under no circumstances! - _host_to_ping_time = {} - - def __init__(self, host, connection_pool, ismaster_response, ping_time, up): - self.host = host - self.pool = connection_pool - self.ismaster_response = ismaster_response - self.ping_time = ping_time - self.up = up - - if ismaster_response['ismaster']: - self.state = PRIMARY - elif ismaster_response.get('secondary'): - self.state = SECONDARY - else: - self.state = OTHER - - self.tags = ismaster_response.get('tags', {}) - self.max_bson_size = ismaster_response.get( - 'maxBsonObjectSize', common.MAX_BSON_SIZE) - self.max_message_size = ismaster_response.get( - 'maxMessageSizeBytes', 2 * self.max_bson_size) - self.min_wire_version = ismaster_response.get('minWireVersion', - common.MIN_WIRE_VERSION) - self.max_wire_version = ismaster_response.get('maxWireVersion', - common.MAX_WIRE_VERSION) - - def clone_with(self, ismaster_response, ping_time_sample): - """Get a clone updated with ismaster response and a single ping time. - """ - ping_time = self.ping_time.clone_with(ping_time_sample) - return Member(self.host, self.pool, ismaster_response, ping_time, True) - - def clone_down(self): - """Get a clone of this Member, but with up=False. - """ - return Member( - self.host, self.pool, self.ismaster_response, self.ping_time, - False) - - @property - def is_primary(self): - return self.state == PRIMARY - - @property - def is_secondary(self): - return self.state == SECONDARY - - def get_avg_ping_time(self): - """Get a moving average of this member's ping times. - """ - if self.host in Member._host_to_ping_time: - # Simulate ping times for unittesting - return Member._host_to_ping_time[self.host] - - return self.ping_time.get() - - def matches_mode(self, mode): - if mode == ReadPreference.PRIMARY and not self.is_primary: - return False - - if mode == ReadPreference.SECONDARY and not self.is_secondary: - return False - - # If we're not primary or secondary, then we're in a state like - # RECOVERING and we don't match any mode - return self.is_primary or self.is_secondary - - def matches_tags(self, tags): - """Return True if this member's tags are a superset of the passed-in - tags. E.g., if this member is tagged {'dc': 'ny', 'rack': '1'}, - then it matches {'dc': 'ny'}. - """ - for key, value in tags.items(): - if key not in self.tags or self.tags[key] != value: - return False - - return True - - def matches_tag_sets(self, tag_sets): - """Return True if this member matches any of the tag sets, e.g. - [{'dc': 'ny'}, {'dc': 'la'}, {}] - """ - for tags in tag_sets: - if self.matches_tags(tags): - return True - - return False - - def __str__(self): - return '' % ( - self.host[0], self.host[1], self.is_primary, self.up) - - class MongoReplicaSetClient(common.BaseObject): """Connection to a MongoDB replica set. """ diff --git a/pymongo/thread_util.py b/pymongo/thread_util.py index a74682095..9b69e4bea 100644 --- a/pymongo/thread_util.py +++ b/pymongo/thread_util.py @@ -33,6 +33,7 @@ try: from gevent.coros import BoundedSemaphore as GeventBoundedSemaphore from gevent.greenlet import SpawnedLink + from gevent.event import Event as GeventEvent except ImportError: have_gevent = False @@ -185,6 +186,32 @@ class Counter(object): return self._counters.get(self.ident.get(), 0) +class Future(object): + """Minimal backport of concurrent.futures.Future. + + event_class makes this Future adaptable for Gevent and other frameworks. + """ + def __init__(self, event_class): + self._event = event_class() + self._result = None + self._exception = None + + def set_result(self, result): + self._result = result + self._event.set() + + def set_exception(self, exc): + self._exception = exc + self._event.set() + + def result(self): + self._event.wait() + if self._exception: + raise self._exception + else: + return self._result + + ### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire class Semaphore: @@ -302,3 +329,10 @@ def create_semaphore(max_size, max_waiters, use_greenlets): return BoundedSemaphore(max_size) else: return MaxWaitersBoundedSemaphoreThread(max_size, max_waiters) + + +def create_event(use_greenlets): + if use_greenlets: + return GeventEvent() + else: + return threading.Event() diff --git a/test/high_availability/test_ha.py b/test/high_availability/test_ha.py index 822825a84..39117be18 100644 --- a/test/high_availability/test_ha.py +++ b/test/high_availability/test_ha.py @@ -27,7 +27,8 @@ from ha_tools import use_greenlets from pymongo.errors import AutoReconnect, OperationFailure, ConnectionFailure -from pymongo.mongo_replica_set_client import Member, Monitor +from pymongo.member import Member +from pymongo.mongo_replica_set_client import Monitor from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.mongo_client import MongoClient, _partition_node from pymongo.read_preferences import ReadPreference, modes diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py new file mode 100644 index 000000000..2bafe5b54 --- /dev/null +++ b/test/pymongo_mocks.py @@ -0,0 +1,146 @@ +# Copyright 2013 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for mocking parts of PyMongo to test other parts.""" + +import socket + +from pymongo import MongoClient, MongoReplicaSetClient +from pymongo.pool import Pool + +from test import host as default_host, port as default_port +from test.utils import my_partial + + +class MockPool(Pool): + def __init__(self, client, pair, *args, **kwargs): + # MockPool gets a 'client' arg, regular pools don't. + self.client = client + self.mock_host, self.mock_port = pair + + # Actually connect to the default server. + Pool.__init__( + self, + pair=(default_host, default_port), + max_size=None, + net_timeout=None, + conn_timeout=20, + use_ssl=False, + use_greenlets=False) + + def get_socket(self, pair=None, force=False): + client = self.client + host_and_port = '%s:%s' % (self.mock_host, self.mock_port) + if host_and_port in client.mock_down_hosts: + raise socket.timeout('mock timeout') + + assert host_and_port in ( + client.mock_standalones + + client.mock_members + + client.mock_mongoses), "bad host: %s" % host_and_port + + sock_info = Pool.get_socket(self, pair, force) + sock_info.mock_host = self.mock_host + sock_info.mock_port = self.mock_port + return sock_info + + +class MockClientBase(object): + def __init__(self, standalones, members, mongoses): + """standalones, etc., are like ['a:1', 'b:2']""" + self.mock_standalones = standalones[:] + self.mock_members = members[:] + + if self.mock_members: + self.mock_primary = self.mock_members[0] + else: + self.mock_primary = None + + self.mock_conf = members[:] + self.mock_mongoses = mongoses[:] + + # Hosts that should raise socket errors. + self.mock_down_hosts = [] + + def kill_host(self, host): + """Host is like 'a:1'.""" + self.mock_down_hosts.append(host) + + def revive_host(self, host): + """Host is like 'a:1'.""" + self.mock_down_hosts.remove(host) + + def mock_is_master(self, host): + # host is like 'a:1'. + if host in self.mock_down_hosts: + raise socket.timeout('mock timeout') + + if host in self.mock_standalones: + return {'ismaster': True} + + if host in self.mock_members: + ismaster = (host == self.mock_primary) + + # Simulate a replica set member. + response = { + 'ismaster': ismaster, + 'secondary': not ismaster, + 'setName': 'rs', + 'hosts': self.mock_conf} + + if self.mock_primary: + response['primary'] = self.mock_primary + + return response + + if host in self.mock_mongoses: + return {'ismaster': True, 'msg': 'isdbgrid'} + + raise AssertionError('Unknown host: %s' % host) + + def simple_command(self, sock_info, dbname, spec): + # __simple_command is also used for authentication, but in this + # test it's only used for ismaster. + assert spec == {'ismaster': 1} + response = self.mock_is_master( + '%s:%s' % (sock_info.mock_host, sock_info.mock_port)) + + ping_time = 10 + return response, ping_time + + +class MockClient(MockClientBase, MongoClient): + def __init__(self, standalones, members, mongoses, *args, **kwargs): + MockClientBase.__init__(self, standalones, members, mongoses) + kwargs['_pool_class'] = my_partial(MockPool, self) + MongoClient.__init__(self, *args, **kwargs) + + def _MongoClient__simple_command(self, sock_info, dbname, spec): + return self.simple_command(sock_info, dbname, spec) + + +class MockReplicaSetClient(MockClientBase, MongoReplicaSetClient): + def __init__(self, standalones, members, mongoses, *args, **kwargs): + MockClientBase.__init__(self, standalones, members, mongoses) + kwargs['_pool_class'] = my_partial(MockPool, self) + MongoReplicaSetClient.__init__(self, *args, **kwargs) + + def _MongoReplicaSetClient__is_master(self, host): + response = self.mock_is_master('%s:%s' % host) + connection_pool = MockPool(self, host) + ping_time = 10 + return response, connection_pool, ping_time + + def _MongoReplicaSetClient__simple_command(self, sock_info, dbname, spec): + return self.simple_command(sock_info, dbname, spec) diff --git a/test/test_client.py b/test/test_client.py index 58b53b570..1a89450f2 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -34,12 +34,14 @@ from pymongo.mongo_client import MongoClient from pymongo.database import Database from pymongo.pool import SocketInfo from pymongo import thread_util -from pymongo.errors import (ConfigurationError, +from pymongo.errors import (AutoReconnect, + ConfigurationError, ConnectionFailure, InvalidName, OperationFailure, PyMongoError) -from test import version, host, port +from test import version, host, port, pair +from test.pymongo_mocks import MockClient from test.utils import (assertRaisesExactly, delay, is_mongos, @@ -47,7 +49,10 @@ from test.utils import (assertRaisesExactly, server_is_master_with_slave, server_started_with_auth, TestRequestMixin, - _TestLazyConnectMixin) + _TestLazyConnectMixin, + lazy_client_trial, + NTHREADS, + get_pool) def get_client(*args, **kwargs): @@ -92,7 +97,7 @@ class TestClient(unittest.TestCase, TestRequestMixin): self.assertIsInstance(c.is_mongos, bool) self.assertIsInstance(c.max_pool_size, int) self.assertIsInstance(c.use_greenlets, bool) - self.assertIsInstance(c.nodes, set) + self.assertIsInstance(c.nodes, frozenset) self.assertIsInstance(c.auto_start_request, bool) self.assertEqual(dict, c.get_document_class()) self.assertIsInstance(c.tz_aware, bool) @@ -510,9 +515,9 @@ class TestClient(unittest.TestCase, TestRequestMixin): def test_timeouts(self): client = MongoClient(host, port, connectTimeoutMS=10500) - self.assertEqual(10.5, client._MongoClient__pool.conn_timeout) + self.assertEqual(10.5, client._MongoClient__member.pool.conn_timeout) client = MongoClient(host, port, socketTimeoutMS=10500) - self.assertEqual(10.5, client._MongoClient__pool.net_timeout) + self.assertEqual(10.5, client._MongoClient__member.pool.net_timeout) def test_network_timeout_validation(self): c = get_client(socketTimeoutMS=10 * 1000) @@ -565,11 +570,11 @@ class TestClient(unittest.TestCase, TestRequestMixin): def test_waitQueueTimeoutMS(self): client = MongoClient(host, port, waitQueueTimeoutMS=2000) - self.assertEqual(client._MongoClient__pool.wait_queue_timeout, 2) + self.assertEqual(client._MongoClient__member.pool.wait_queue_timeout, 2) def test_waitQueueMultiple(self): client = MongoClient(host, port, max_pool_size=3, waitQueueMultiple=2) - pool = client._MongoClient__pool + pool = client._MongoClient__member.pool self.assertEqual(pool.wait_queue_multiple, 2) self.assertEqual(pool._socket_semaphore.waiter_semaphore.counter, 6) @@ -647,26 +652,25 @@ class TestClient(unittest.TestCase, TestRequestMixin): # The socket used for the previous commands has been returned to the # pool - self.assertEqual(1, len(client._MongoClient__pool.sockets)) + self.assertEqual(1, len(client._MongoClient__member.pool.sockets)) # We need exec here because if the Python version is less than 2.6 # these with-statements won't even compile. exec """ with contextlib.closing(client): self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) -self.assertEqual(0, len(client._MongoClient__pool.sockets)) +self.assertEqual(None, client._MongoClient__member) """ exec """ with get_client() as client: self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) -# Calling client.close() has reset the pool -self.assertEqual(0, len(client._MongoClient__pool.sockets)) +self.assertEqual(None, client._MongoClient__member) """ def test_with_start_request(self): client = get_client() - pool = client._MongoClient__pool + pool = client._MongoClient__member.pool # No request started self.assertNoRequest(pool) @@ -716,12 +720,11 @@ with client.start_request() as request: client = get_client(auto_start_request=True) self.assertTrue(client.auto_start_request) - self.assertTrue(client.in_request()) - pool = client._MongoClient__pool - # Request started already, just from MongoClient constructor - it's a - # bit weird, but MongoClient does some socket stuff when it initializes - # and it ends up with a request socket + # Assure we acquire a request socket. + client.pymongo_test.test.find_one() + self.assertTrue(client.in_request()) + pool = client._MongoClient__member.pool self.assertRequestSocket(pool) self.assertSameSock(pool) @@ -737,7 +740,7 @@ with client.start_request() as request: def test_nested_request(self): # auto_start_request is False client = get_client() - pool = client._MongoClient__pool + pool = client._MongoClient__member.pool self.assertFalse(client.in_request()) # Start and end request @@ -765,7 +768,7 @@ with client.start_request() as request: def test_request_threads(self): client = get_client(auto_start_request=False) - pool = client._MongoClient__pool + pool = client._MongoClient__member.pool self.assertNotInRequestAndDifferentSock(client, pool) started_request, ended_request = threading.Event(), threading.Event() @@ -854,7 +857,7 @@ with client.start_request() as request: # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. c = get_client() - pool = c._MongoClient__pool + pool = c._MongoClient__member.pool self.assertEqual(1, len(pool.sockets)) old_sock_info = iter(pool.sockets).next() c.pymongo_test.test.drop() @@ -871,9 +874,10 @@ with client.start_request() as request: # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. c = get_client(auto_start_request=True) - pool = c._MongoClient__pool + pool = get_pool(c) - # MongoClient has reserved a socket for this thread + # Pool reserves a socket for this thread. + c.pymongo_test.test.find_one() self.assertTrue(isinstance(pool._get_request_state(), SocketInfo)) old_sock_info = pool._get_request_state() @@ -908,5 +912,110 @@ class TestClientLazyConnect(unittest.TestCase, _TestLazyConnectMixin): return get_client(**kwargs) +class TestClientLazyConnectBadSeeds(unittest.TestCase): + def _get_client(self, **kwargs): + kwargs.setdefault('connectTimeoutMS', 100) + + # Assume there are no open mongods listening on a.com, b.com, .... + bad_seeds = ['%s.com' % chr(ord('a') + i) for i in range(10)] + return MongoClient(bad_seeds, **kwargs) + + def test_connect(self): + def reset(dummy): + pass + + def connect(collection, dummy): + self.assertRaises(AutoReconnect, collection.find_one) + + def test(collection): + client = collection.database.connection + self.assertEqual(0, len(client.nodes)) + + lazy_client_trial( + reset, connect, test, + self._get_client, use_greenlets=False) + + +class TestClientLazyConnectOneGoodSeed( + unittest.TestCase, + _TestLazyConnectMixin): + + def _get_client(self, **kwargs): + kwargs.setdefault('connectTimeoutMS', 100) + + # Assume there are no open mongods listening on a.com, b.com, .... + bad_seeds = ['%s.com' % chr(ord('a') + i) for i in range(10)] + seeds = bad_seeds + [pair] + + # MongoClient puts the seeds in a set before iterating, so order is + # undefined. + return MongoClient(seeds, **kwargs) + + def test_insert(self): + def reset(collection): + collection.drop() + + def insert(collection, dummy): + collection.insert({}) + + def test(collection): + self.assertEqual(NTHREADS, collection.count()) + + lazy_client_trial( + reset, insert, test, + self._get_client, use_greenlets=False) + + +class TestMongoClientFailover(unittest.TestCase): + def test_discover_primary(self): + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='b:2', # Pass a secondary. + replicaSet='rs') + + self.assertEqual('a', c.host) + self.assertEqual(1, c.port) + self.assertEqual(3, len(c.nodes)) + + # Fail over. + c.kill_host('a:1') + c.mock_primary = 'b:2' + + # Force reconnect. + c.disconnect() + c.db.collection.find_one() + self.assertEqual('b', c.host) + self.assertEqual(2, c.port) + + # a:1 is still in nodes. + self.assertEqual(3, len(c.nodes)) + + def test_reconnect(self): + # Verify the node list isn't forgotten during a network failure. + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='b:2', # Pass a secondary. + replicaSet='rs') + + # Total failure. + c.kill_host('a:1') + c.kill_host('b:2') + c.kill_host('c:3') + + # MongoClient discovers it's alone. + self.assertRaises(AutoReconnect, c.db.collection.find_one) + + # But it remembers its node list. + self.assertEqual(3, len(c.nodes)) + + # So it can reconnect. + c.revive_host('a:1') + c.db.collection.find_one() + + if __name__ == "__main__": unittest.main() diff --git a/test/test_collection.py b/test/test_collection.py index 076c16b0e..ceb29e2f2 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -48,7 +48,7 @@ from pymongo.errors import (DuplicateKeyError, OperationFailure, WTimeoutError) from test.test_client import get_client -from test.utils import is_mongos, joinall, enable_text_search +from test.utils import is_mongos, joinall, enable_text_search, get_pool from test import (qcheck, version) @@ -1661,7 +1661,7 @@ class TestCollection(unittest.TestCase): self.db.test.insert([{'i': i} for i in xrange(150)]) client = get_client(max_pool_size=1) - socks = client._MongoClient__pool.sockets + socks = get_pool(client).sockets self.assertEqual(1, len(socks)) # Make sure the socket is returned after exhaustion. diff --git a/test/test_database.py b/test/test_database.py index 8129f3c51..6a7b1def3 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -551,7 +551,6 @@ class TestDatabase(unittest.TestCase): request_cx = get_client(auto_start_request=True) request_db = request_cx.pymongo_test - self.assertTrue(request_cx.in_request()) self.assertTrue(request_db.authenticate("mike", "password")) self.assertTrue(request_cx.in_request()) finally: diff --git a/test/test_legacy_connections.py b/test/test_legacy_connections.py index 0592e74f2..59eb40d69 100644 --- a/test/test_legacy_connections.py +++ b/test/test_legacy_connections.py @@ -52,7 +52,7 @@ class TestConnection(unittest.TestCase): # To preserve legacy Connection's behavior, max_size should be None. # Pool should handle this without error. - self.assertEqual(None, c._MongoClient__pool.max_size) + self.assertEqual(None, get_pool(c).max_size) c.end_request() # Connection's network_timeout argument is translated into diff --git a/test/test_master_slave_connection.py b/test/test_master_slave_connection.py index 961593e6c..9298272f8 100644 --- a/test/test_master_slave_connection.py +++ b/test/test_master_slave_connection.py @@ -35,7 +35,8 @@ from pymongo.mongo_client import MongoClient from pymongo.collection import Collection from pymongo.master_slave_connection import MasterSlaveConnection from test import host, port, host2, port2, host3, port3 -from test.utils import TestRequestMixin +from test.utils import TestRequestMixin, get_pool + class TestMasterSlaveConnection(unittest.TestCase, TestRequestMixin): @@ -275,7 +276,7 @@ class TestMasterSlaveConnection(unittest.TestCase, TestRequestMixin): client = self.client # In a request, all ops go through master - pool = client.master._MongoClient__pool + pool = get_pool(client.master) client.master.end_request() self.assertNotInRequestAndDifferentSock(client, pool) diff --git a/test/test_mongos_ha.py b/test/test_mongos_ha.py new file mode 100644 index 000000000..14e24c871 --- /dev/null +++ b/test/test_mongos_ha.py @@ -0,0 +1,124 @@ +# Copyright 2013 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test MongoClient's mongos high-availability features using a mock.""" + +import sys +import threading +import unittest + +sys.path[0:0] = [""] + +from pymongo.errors import AutoReconnect +from test.pymongo_mocks import MockClient + + +class FindOne(threading.Thread): + def __init__(self, client): + super(FindOne, self).__init__() + self.client = client + self.passed = False + + def run(self): + self.client.db.collection.find_one() + self.passed = True # No exception raised. + + +def do_find_one(client, nthreads): + threads = [FindOne(client) for _ in range(nthreads)] + for t in threads: + t.start() + + for t in threads: + t.join() + + for t in threads: + assert t.passed + + +class TestMongosHA(unittest.TestCase): + def mock_client(self, connect): + return MockClient( + standalones=[], + members=[], + mongoses=['a:1', 'b:2', 'c:3'], + host='a:1,b:2,c:3', + _connect=connect) + + def test_lazy_connect(self): + nthreads = 10 + client = self.mock_client(False) + self.assertEqual(0, len(client.nodes)) + + # Trigger initial connection. + do_find_one(client, nthreads) + self.assertEqual(3, len(client.nodes)) + + def test_reconnect(self): + nthreads = 10 + client = self.mock_client(True) + self.assertEqual(3, len(client.nodes)) + + # Trigger reconnect. + client.disconnect() + do_find_one(client, nthreads) + self.assertEqual(3, len(client.nodes)) + + def test_failover(self): + nthreads = 1 + + # ['1:1', '2:2', '3:3', ...] + mock_hosts = ['%d:%d' % (i, i) for i in range(50)] + client = MockClient( + standalones=[], + members=[], + mongoses=mock_hosts, + host=','.join(mock_hosts)) + + self.assertEqual(len(mock_hosts), len(client.nodes)) + + # Our chosen mongos goes down. + client.kill_host('%s:%s' % (client.host, client.port)) + + # Trigger failover. AutoReconnect should be raised exactly once. + errors = [] + passed = [] + + def f(): + try: + client.db.collection.find_one() + except AutoReconnect: + errors.append(True) + + # Second attempt succeeds. + client.db.collection.find_one() + + passed.append(True) + + threads = [threading.Thread(target=f) for _ in range(nthreads)] + for t in threads: + t.start() + + for t in threads: + t.join() + + self.assertEqual(1, len(errors)) + self.assertEqual(nthreads, len(passed)) + + # Down host is still in list. + self.assertEqual(len(mock_hosts), len(client.nodes)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_pooling.py b/test/test_pooling.py index 487ec16f8..21a18765b 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -27,6 +27,7 @@ from test import host, port from test.test_pooling_base import ( _TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets, _TestPoolSocketSharing, _TestWaitQueueMultiple, one) +from test.utils import get_pool class TestPoolingThreads(_TestPooling, unittest.TestCase): @@ -117,15 +118,15 @@ class TestPoolingThreads(_TestPooling, unittest.TestCase): a.pymongo_test.test.remove() a.pymongo_test.test.insert({'_id':1}) a.pymongo_test.test.find_one() - self.assertEqual(1, len(a._MongoClient__pool.sockets)) - a_sock = one(a._MongoClient__pool.sockets) + self.assertEqual(1, len(get_pool(a).sockets)) + a_sock = one(get_pool(a).sockets) def loop(pipe): c = self.get_client(auto_start_request=False) - self.assertEqual(1,len(c._MongoClient__pool.sockets)) + self.assertEqual(1,len(get_pool(c).sockets)) c.pymongo_test.test.find_one() - self.assertEqual(1,len(c._MongoClient__pool.sockets)) - pipe.send(one(c._MongoClient__pool.sockets).sock.getsockname()) + self.assertEqual(1,len(get_pool(c).sockets)) + pipe.send(one(get_pool(c).sockets).sock.getsockname()) cp1, cc1 = Pipe() cp2, cc2 = Pipe() @@ -155,7 +156,7 @@ class TestPoolingThreads(_TestPooling, unittest.TestCase): self.assertTrue(b_sock != c_sock) # a_sock, created by parent process, is still in the pool - d_sock = a._MongoClient__pool.get_socket((a.host, a.port)) + d_sock = get_pool(a).get_socket() self.assertEqual(a_sock, d_sock) d_sock.close() diff --git a/test/test_pooling_base.py b/test/test_pooling_base.py index f5d9682f2..b2805d1ee 100644 --- a/test/test_pooling_base.py +++ b/test/test_pooling_base.py @@ -34,7 +34,7 @@ from pymongo.errors import ConfigurationError, ConnectionFailure from pymongo.errors import ExceededMaxWaiters from test import version, host, port from test.test_client import get_client -from test.utils import delay, is_mongos, one +from test.utils import delay, is_mongos, one, get_pool N = 50 DB = "pymongo-pooling-tests" @@ -190,7 +190,7 @@ class OneOp(MongoThread): super(OneOp, self).__init__(ut) def run_mongo_thread(self): - pool = self.client._MongoClient__pool + pool = get_pool(self.client) assert len(pool.sockets) == 1, "Expected 1 socket, found %d" % ( len(pool.sockets) ) @@ -358,24 +358,31 @@ class _TestPoolingBase(object): time.sleep(seconds) def assert_no_request(self): - self.assertEqual( - NO_REQUEST, self.c._MongoClient__pool._get_request_state() + self.assertTrue( + self.c._MongoClient__member is None or + NO_REQUEST == get_pool(self.c)._get_request_state() ) def assert_request_without_socket(self): self.assertEqual( - NO_SOCKET_YET, self.c._MongoClient__pool._get_request_state() + NO_SOCKET_YET, get_pool(self.c)._get_request_state() ) def assert_request_with_socket(self): self.assertTrue(isinstance( - self.c._MongoClient__pool._get_request_state(), SocketInfo + get_pool(self.c)._get_request_state(), SocketInfo )) def assert_pool_size(self, pool_size): - self.assertEqual( - pool_size, len(self.c._MongoClient__pool.sockets) - ) + if pool_size == 0: + self.assertTrue( + self.c._MongoClient__member is None + or not get_pool(self.c).sockets + ) + else: + self.assertEqual( + pool_size, len(get_pool(self.c).sockets) + ) class _TestPooling(_TestPoolingBase): @@ -462,35 +469,35 @@ class _TestPooling(_TestPoolingBase): def test_multiple_connections(self): a = self.get_client(auto_start_request=False) b = self.get_client(auto_start_request=False) - self.assertEqual(1, len(a._MongoClient__pool.sockets)) - self.assertEqual(1, len(b._MongoClient__pool.sockets)) + self.assertEqual(1, len(get_pool(a).sockets)) + self.assertEqual(1, len(get_pool(b).sockets)) a.start_request() a.pymongo_test.test.find_one() - self.assertEqual(0, len(a._MongoClient__pool.sockets)) + self.assertEqual(0, len(get_pool(a).sockets)) a.end_request() - self.assertEqual(1, len(a._MongoClient__pool.sockets)) - self.assertEqual(1, len(b._MongoClient__pool.sockets)) - a_sock = one(a._MongoClient__pool.sockets) + self.assertEqual(1, len(get_pool(a).sockets)) + self.assertEqual(1, len(get_pool(b).sockets)) + a_sock = one(get_pool(a).sockets) b.end_request() - self.assertEqual(1, len(a._MongoClient__pool.sockets)) - self.assertEqual(1, len(b._MongoClient__pool.sockets)) + self.assertEqual(1, len(get_pool(a).sockets)) + self.assertEqual(1, len(get_pool(b).sockets)) b.start_request() b.pymongo_test.test.find_one() - self.assertEqual(1, len(a._MongoClient__pool.sockets)) - self.assertEqual(0, len(b._MongoClient__pool.sockets)) + self.assertEqual(1, len(get_pool(a).sockets)) + self.assertEqual(0, len(get_pool(b).sockets)) b.end_request() - b_sock = one(b._MongoClient__pool.sockets) + b_sock = one(get_pool(b).sockets) b.pymongo_test.test.find_one() a.pymongo_test.test.find_one() self.assertEqual(b_sock, - b._MongoClient__pool.get_socket((b.host, b.port))) + get_pool(b).get_socket()) self.assertEqual(a_sock, - a._MongoClient__pool.get_socket((a.host, a.port))) + get_pool(a).get_socket()) a_sock.close() b_sock.close() @@ -823,7 +830,7 @@ class _TestMaxPoolSize(_TestPoolingBase): # Socket-reclamation doesn't work in Jython if not sys.platform.startswith('java'): - cx_pool = c._MongoClient__pool + cx_pool = get_pool(c) # Socket-reclamation depends on timely garbage-collection if 'PyPy' in sys.version: @@ -923,7 +930,7 @@ class _TestMaxPoolSize(_TestPoolingBase): for t in threads: self.assertTrue(t.passed) - cx_pool = c._MongoClient__pool + cx_pool = get_pool(c) # Socket-reclamation depends on timely garbage-collection if 'PyPy' in sys.version: diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index 08dcfc8b2..d4ad87940 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -35,8 +35,8 @@ from bson.son import SON from bson.tz_util import utc from pymongo.mongo_client import MongoClient from pymongo.read_preferences import ReadPreference +from pymongo.member import PRIMARY, SECONDARY, OTHER from pymongo.mongo_replica_set_client import MongoReplicaSetClient -from pymongo.mongo_replica_set_client import PRIMARY, SECONDARY, OTHER from pymongo.mongo_replica_set_client import _partition_node, have_gevent from pymongo.database import Database from pymongo.pool import SocketInfo diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 1575ce161..2f8718a2d 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -12,118 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test MongoReplicaSetClients and replica set configuration changes.""" +"""Test clients and replica set configuration changes, using mocks.""" -import socket import sys import unittest sys.path[0:0] = [""] -from pymongo import MongoClient from pymongo.errors import AutoReconnect -from pymongo.mongo_replica_set_client import MongoReplicaSetClient -from pymongo.pool import Pool -from test import host as default_host, port as default_port - - -class MockPool(Pool): - def __init__(self, pair, *args, **kwargs): - if pair: - # RS client passes 'pair' to Pool's constructor. - self.mock_host, self.mock_port = pair - else: - # MongoClient passes pair to get_socket() instead. - self.mock_host, self.mock_port = None, None - - Pool.__init__( - self, - pair=(default_host, default_port), - max_size=None, - net_timeout=None, - conn_timeout=20, - use_ssl=False, - use_greenlets=False) - - def get_socket(self, pair=None, force=False): - sock_info = Pool.get_socket(self, (default_host, default_port), force) - sock_info.host = self.mock_host or pair[0] - return sock_info - - -MOCK_HOSTS = ['a:27017', 'b:27017', 'c:27017'] -MOCK_PRIMARY = MOCK_HOSTS[0] -MOCK_RS_NAME = 'rs' - - -class MockClientBase(object): - def __init__(self): - self.mock_hosts = MOCK_HOSTS[:] - - # Hosts that should raise socket errors. - self.mock_down_hosts = [] - - # Hosts that should respond to ismaster as if they're standalone. - self.mock_standalone_hosts = [] - - def mock_is_master(self, host): - if host in self.mock_down_hosts: - raise socket.timeout('timed out') - - if host in self.mock_standalone_hosts: - return {'ismaster': True} - - if host not in self.mock_hosts: - # Host removed from set by a reconfig. - return {'ismaster': False, 'secondary': False} - - ismaster = host == MOCK_PRIMARY - - # Simulate a replica set member. - return { - 'ismaster': ismaster, - 'secondary': not ismaster, - 'setName': MOCK_RS_NAME, - 'hosts': self.mock_hosts} - - def simple_command(self, sock_info, dbname, spec): - # __simple_command is also used for authentication, but in this - # test it's only used for ismaster. - assert spec == {'ismaster': 1} - response = self.mock_is_master('%s:%s' % (sock_info.host, 27017)) - ping_time = 10 - return response, ping_time - - -class MockClient(MockClientBase, MongoClient): - def __init__(self, hosts): - MockClientBase.__init__(self) - MongoClient.__init__( - self, - hosts, - replicaSet=MOCK_RS_NAME, - _pool_class=MockPool) - - def _MongoClient__simple_command(self, sock_info, dbname, spec): - return self.simple_command(sock_info, dbname, spec) - - -class MockReplicaSetClient(MockClientBase, MongoReplicaSetClient): - def __init__(self, hosts): - MockClientBase.__init__(self) - MongoReplicaSetClient.__init__( - self, - hosts, - replicaSet=MOCK_RS_NAME) - - def _MongoReplicaSetClient__is_master(self, host): - response = self.mock_is_master('%s:%s' % host) - connection_pool = MockPool(host) - ping_time = 10 - return response, connection_pool, ping_time - - def _MongoReplicaSetClient__simple_command(self, sock_info, dbname, spec): - return self.simple_command(sock_info, dbname, spec) +from test.pymongo_mocks import MockClient, MockReplicaSetClient class TestSecondaryBecomesStandalone(unittest.TestCase): @@ -131,17 +28,24 @@ class TestSecondaryBecomesStandalone(unittest.TestCase): # brings it back up as standalone, without updating the other # members' config. Verify we don't continue using it. def test_client(self): - c = MockClient(','.join(MOCK_HOSTS)) + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='a:1,b:2,c:3', + replicaSet='rs') # MongoClient connects to primary by default. self.assertEqual('a', c.host) - self.assertEqual(27017, c.port) + self.assertEqual(1, c.port) # C is brought up as a standalone. - c.mock_standalone_hosts.append('c:27017') + c.mock_members.remove('c:3') + c.mock_standalones.append('c:3') # Fail over. - c.mock_down_hosts = ['a:27017', 'b:27017'] + c.kill_host('a:1') + c.kill_host('b:2') # Force reconnect. c.disconnect() @@ -157,30 +61,71 @@ class TestSecondaryBecomesStandalone(unittest.TestCase): self.assertEqual(None, c.port) def test_replica_set_client(self): - c = MockReplicaSetClient(','.join(MOCK_HOSTS)) - self.assertTrue(('c', 27017) in c.secondaries) + c = MockReplicaSetClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='a:1,b:2,c:3', + replicaSet='rs') + + self.assertTrue(('b', 2) in c.secondaries) + self.assertTrue(('c', 3) in c.secondaries) # C is brought up as a standalone. - c.mock_standalone_hosts.append('c:27017') + c.mock_members.remove('c:3') + c.mock_standalones.append('c:3') c.refresh() - self.assertEqual(('a', 27017), c.primary) - self.assertEqual(set([('b', 27017)]), c.secondaries) + self.assertEqual(('a', 1), c.primary) + self.assertEqual(set([('b', 2)]), c.secondaries) class TestSecondaryRemoved(unittest.TestCase): # An administrator removes a secondary from a 3-node set *without* # restarting it as standalone. def test_replica_set_client(self): - c = MockReplicaSetClient(','.join(MOCK_HOSTS)) - self.assertTrue(('c', 27017) in c.secondaries) + c = MockReplicaSetClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='a:1,b:2,c:3', + replicaSet='rs') + + self.assertTrue(('b', 2) in c.secondaries) + self.assertTrue(('c', 3) in c.secondaries) # C is removed. - c.mock_hosts.remove('c:27017') + c.mock_conf.remove('c:3') c.refresh() - self.assertEqual(('a', 27017), c.primary) - self.assertEqual(set([('b', 27017)]), c.secondaries) + self.assertEqual(('a', 1), c.primary) + self.assertEqual(set([('b', 2)]), c.secondaries) + + +class TestSecondaryAdded(unittest.TestCase): + def test_client(self): + c = MockClient( + standalones=[], + members=['a:1', 'b:2'], + mongoses=[], + host='a:1', + replicaSet='rs') + + # MongoClient connects to primary by default. + self.assertEqual('a', c.host) + self.assertEqual(1, c.port) + self.assertEqual(set([('a', 1), ('b', 2)]), c.nodes) + + # C is added. + c.mock_members.append('c:3') + c.mock_conf.append('c:3') + + c.disconnect() + c.db.collection.find_one() + + self.assertEqual('a', c.host) + self.assertEqual(1, c.port) + self.assertEqual(set([('a', 1), ('b', 2), ('c', 3)]), c.nodes) if __name__ == "__main__": diff --git a/test/test_threads.py b/test/test_threads.py index 82443d6ba..6a6a7a1b1 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -140,8 +140,8 @@ class FindPauseFind(RendezvousThread): # acquire a socket list(self.collection.find()) - self.pool = get_pool(self.collection.database.connection) - socket_info = self.pool._get_request_state() + pool = get_pool(self.collection.database.connection) + socket_info = pool._get_request_state() assert isinstance(socket_info, SocketInfo) self.request_sock = socket_info.sock assert not _closed(self.request_sock) @@ -151,12 +151,12 @@ class FindPauseFind(RendezvousThread): # because it's not our request socket anymore assert _closed(self.request_sock) - # if disconnect() properly closed all threads' request sockets, then - # this won't raise AutoReconnect because it will acquire a new socket - assert self.request_sock == self.pool._get_request_state().sock + # if disconnect() properly replaced the pool, then this won't raise + # AutoReconnect because it will acquire a new socket list(self.collection.find()) assert self.collection.database.connection.in_request() - assert self.request_sock != self.pool._get_request_state().sock + pool = get_pool(self.collection.database.connection) + assert self.request_sock != pool._get_request_state().sock class BaseTestThreads(object): diff --git a/test/utils.py b/test/utils.py index 84dfa676a..adf2486c5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -298,7 +298,7 @@ def assertReadFromAll(testcase, rsc, members, *args, **kwargs): def get_pool(client): if isinstance(client, MongoClient): - return client._MongoClient__pool + return client._MongoClient__member.pool elif isinstance(client, MongoReplicaSetClient): rs_state = client._MongoReplicaSetClient__rs_state return rs_state.primary_member.pool @@ -316,27 +316,16 @@ class TestRequestMixin(object): """Inherit from this class and from unittest.TestCase to get some convenient methods for testing connection pools and requests """ - def get_sock(self, pool): - # MongoClient calls Pool.get_socket((host, port)), whereas RSC sets - # Pool.pair at construction-time and just calls Pool.get_socket(). - # Deal with either case so we can use TestRequestMixin to test pools - # from MongoClient and from RSC. - if not pool.pair: - sock_info = pool.get_socket((host, port)) - else: - sock_info = pool.get_socket() - return sock_info - def assertSameSock(self, pool): - sock_info0 = self.get_sock(pool) - sock_info1 = self.get_sock(pool) + sock_info0 = pool.get_socket() + sock_info1 = pool.get_socket() self.assertEqual(sock_info0, sock_info1) pool.maybe_return_socket(sock_info0) pool.maybe_return_socket(sock_info1) def assertDifferentSock(self, pool): - sock_info0 = self.get_sock(pool) - sock_info1 = self.get_sock(pool) + sock_info0 = pool.get_socket() + sock_info1 = pool.get_socket() self.assertNotEqual(sock_info0, sock_info1) pool.maybe_return_socket(sock_info0) pool.maybe_return_socket(sock_info1) @@ -367,6 +356,80 @@ class TestRequestMixin(object): self.assertDifferentSock(pool) +# Constants for run_threads and _TestLazyConnectMixin. +NTRIALS = 10 +NTHREADS = 10 + + +def run_threads(collection, target, use_greenlets): + """Run a target function in many threads. + + target is a function taking a Collection and an integer. + """ + threads = [] + for i in range(NTHREADS): + bound_target = my_partial(target, collection, i) + if use_greenlets: + threads.append(gevent.Greenlet(run=bound_target)) + else: + threads.append(threading.Thread(target=bound_target)) + + for t in threads: + t.start() + + for t in threads: + t.join(30) + if use_greenlets: + # bool(Greenlet) is True if it's alive. + assert not t + else: + assert not t.isAlive() + + +def lazy_client_trial(reset, target, test, get_client, use_greenlets): + """Test concurrent operations on a lazily-connecting client. + + `reset` takes a collection and resets it for the next trial. + + `target` takes a lazily-connecting collection and an index from + 0 to NTHREADS, and performs some operation, e.g. an insert. + + `test` takes the lazily-connecting collection and asserts a + post-condition to prove `target` succeeded. + """ + if use_greenlets and not has_gevent: + raise SkipTest('Gevent not installed') + + collection = MongoClient(host, port).pymongo_test.test + + # Make concurrency bugs more likely to manifest. + interval = None + if not sys.platform.startswith('java'): + if sys.version_info >= (3, 2): + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) + else: + interval = sys.getcheckinterval() + sys.setcheckinterval(1) + + try: + for i in range(NTRIALS): + reset(collection) + lazy_client = get_client( + _connect=False, use_greenlets=use_greenlets) + + lazy_collection = lazy_client.pymongo_test.test + run_threads(lazy_collection, target, use_greenlets) + test(lazy_collection) + + finally: + if not sys.platform.startswith('java'): + if sys.version_info >= (3, 2): + sys.setswitchinterval(interval) + else: + sys.setcheckinterval(interval) + + class _TestLazyConnectMixin(object): """Test concurrent operations on a lazily-connecting client. @@ -377,75 +440,9 @@ class _TestLazyConnectMixin(object): Set use_greenlets = True to test with Gevent. """ use_greenlets = False - ntrials = 10 - nthreads = 10 - interval = None - def run_threads(self, collection, target): - """Run a target function in many threads. - - target is a function taking a Collection and an integer. - """ - threads = [] - for i in range(self.nthreads): - bound_target = my_partial(target, collection, i) - if self.use_greenlets: - threads.append(gevent.Greenlet(run=bound_target)) - else: - threads.append(threading.Thread(target=bound_target)) - - for t in threads: - t.start() - - for t in threads: - t.join(60) - if self.use_greenlets: - # bool(Greenlet) is True if it's alive. - assert not t - else: - assert not t.isAlive() - - def trial(self, reset, target, test): - """Test concurrent operations on a lazily-connecting client. - - `reset` takes a collection and resets it for the next trial. - - `target` takes a lazily-connecting collection and an index from - 0 to nthreads, and performs some operation, e.g. an insert. - - `test` takes a collection and asserts a post-condition to prove - `target` succeeded. - """ - if self.use_greenlets and not has_gevent: - raise SkipTest('Gevent not installed') - - collection = self._get_client().pymongo_test.test - - # Make concurrency bugs more likely to manifest. - if not sys.platform.startswith('java'): - if sys.version_info >= (3, 2): - self.interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - else: - self.interval = sys.getcheckinterval() - sys.setcheckinterval(1) - - try: - for i in range(self.ntrials): - reset(collection) - lazy_client = self._get_client( - _connect=False, use_greenlets=self.use_greenlets) - - lazy_collection = lazy_client.pymongo_test.test - self.run_threads(lazy_collection, target) - test(collection) - - finally: - if not sys.platform.startswith('java'): - if sys.version_info >= (3, 2): - sys.setswitchinterval(self.interval) - else: - sys.setcheckinterval(self.interval) + NTRIALS = 10 + NTHREADS = 10 def test_insert(self): def reset(collection): @@ -455,9 +452,11 @@ class _TestLazyConnectMixin(object): collection.insert({}) def test(collection): - self.assertEqual(self.nthreads, collection.count()) + self.assertEqual(NTHREADS, collection.count()) - self.trial(reset, insert, test) + lazy_client_trial( + reset, insert, test, + self._get_client, self.use_greenlets) def test_save(self): def reset(collection): @@ -467,9 +466,11 @@ class _TestLazyConnectMixin(object): collection.save({}) def test(collection): - self.assertEqual(self.nthreads, collection.count()) + self.assertEqual(NTHREADS, collection.count()) - self.trial(reset, save, test) + lazy_client_trial( + reset, save, test, + self._get_client, self.use_greenlets) def test_update(self): def reset(collection): @@ -481,14 +482,16 @@ class _TestLazyConnectMixin(object): collection.update({}, {'$inc': {'i': 1}}) def test(collection): - self.assertEqual(self.nthreads, collection.find_one()['i']) + self.assertEqual(NTHREADS, collection.find_one()['i']) - self.trial(reset, update, test) + lazy_client_trial( + reset, update, test, + self._get_client, self.use_greenlets) def test_remove(self): def reset(collection): collection.drop() - collection.insert([{'i': i} for i in range(self.nthreads)]) + collection.insert([{'i': i} for i in range(NTHREADS)]) def remove(collection, i): collection.remove({'i': i}) @@ -496,7 +499,9 @@ class _TestLazyConnectMixin(object): def test(collection): self.assertEqual(0, collection.count()) - self.trial(reset, remove, test) + lazy_client_trial( + reset, remove, test, + self._get_client, self.use_greenlets) def test_find_one(self): results = [] @@ -510,9 +515,11 @@ class _TestLazyConnectMixin(object): results.append(collection.find_one()) def test(collection): - self.assertEqual(self.nthreads, len(results)) + self.assertEqual(NTHREADS, len(results)) - self.trial(reset, find_one, test) + lazy_client_trial( + reset, find_one, test, + self._get_client, self.use_greenlets) def test_max_bson_size(self): # Client should have sane defaults before connecting, and should update