PYTHON-525 Reimplement MongoClient to use Cluster.

Replace MongoClient with an implementation that relies on Cluster and Server. The new MongoClient takes over MongoReplicaSetClient's responsibilities.

Authentication, secondary-pinning, and Mongos high-availability are broken and will be reimplemented in a future commit. RS tests are temporarily disabled.
This commit is contained in:
A. Jesse Jiryu Davis 2014-08-13 15:35:17 -04:00
parent 564d20aa76
commit 2f9d24ade6
40 changed files with 1301 additions and 1471 deletions

View File

@ -22,7 +22,6 @@
.. autoattribute:: is_mongos
.. autoattribute:: max_pool_size
.. autoattribute:: nodes
.. autoattribute:: auto_start_request
.. autoattribute:: document_class
.. autoattribute:: tz_aware
.. autoattribute:: max_bson_size

View File

@ -25,8 +25,35 @@ see :doc:`PyMongo's Gevent documentation <examples/gevent>`.
:class:`~pymongo.MongoClient` Changes
.....................................
:class:`~pymongo.mongo_client.MongoClient` is now the one and only
client class for a standalone server, mongos, or replica set.
It includes the functionality that had been split into
``MongoReplicaSetClient``: it can connect to a replica set, discover all its
members, and monitor the set for stepdowns, elections, and reconfigs.
The obsolete ``MasterSlaveConnection`` class is removed.
The :class:`~pymongo.mongo_client.MongoClient` constructor no
longer blocks while connecting to the server or servers, and it no
longer raises :class:`~pymongo.errors.ConnectionFailure` if they
are unavailable, nor :class:`~pymongo.errors.ConfigurationError`
if the user's credentials are wrong. Instead, the constructor
returns immediately and launches the connection process on
background threads.
The ``connect`` option is added and ``auto_start_request`` is removed.
In PyMongo 2.x, :class:`~pymongo.MongoClient` accepted a list of standalone
MongoDB servers and used the first it could connect to::
MongoClient(['host1.com:27017', 'host2.com:27017'])
A list of multiple standalones is no longer supported; if multiple servers
are listed they must be members of the same replica set, or mongoses in the
same sharded cluster.
The second parameter to :meth:`~pymongo.MongoClient.close_cursor` is renamed
from ``_conn_id`` to ``address`` and is no longer optional.
from ``_conn_id`` to ``address``.
:meth:`~pymongo.MongoClient.set_cursor_manager` is no longer deprecated.

View File

@ -25,8 +25,7 @@ from gridfs.errors import (NoFile,
from gridfs.grid_file import (GridIn,
GridOut,
GridOutCursor)
from pymongo import (MongoClient,
ASCENDING,
from pymongo import (ASCENDING,
DESCENDING)
from pymongo.database import Database
@ -34,18 +33,21 @@ from pymongo.database import Database
class GridFS(object):
"""An instance of GridFS on top of a single Database.
"""
def __init__(self, database, collection="fs", _connect=True):
def __init__(self, database, collection="fs"):
"""Create a new instance of :class:`GridFS`.
Raises :class:`TypeError` if `database` is not an instance of
:class:`~pymongo.database.Database`.
The `connect` parameter ensures that the underlying
:class:`~pymongo.mongo_client.MongoClient` is connected to a server,
and creates an index on the "chunks" collection if needed.
:Parameters:
- `database`: database to use
- `collection` (optional): root collection to use
.. versionadded:: 1.6
The `collection` parameter.
- `connect` (optional): whether to begin connecting the client in
the background
.. mongodoc:: gridfs
"""
@ -56,15 +58,10 @@ class GridFS(object):
self.__collection = database[collection]
self.__files = self.__collection.files
self.__chunks = self.__collection.chunks
if _connect:
self.__ensure_index_files_id()
def __is_secondary(self):
client = self.__database.connection
# Connect the client, so we know if it's connected to the primary.
client._ensure_connected()
return isinstance(client, MongoClient) and not client.is_primary
return not client._is_writable()
def __ensure_index_files_id(self):
if not self.__is_secondary():
@ -156,7 +153,11 @@ class GridFS(object):
.. versionadded:: 1.6
"""
return GridOut(self.__collection, file_id)
gout = GridOut(self.__collection, file_id)
# Raise NoFile now, instead of on first attribute access.
gout._ensure_file()
return gout
def get_version(self, filename=None, version=-1, **kwargs):
"""Get a file from GridFS by ``"filename"`` or metadata fields.

View File

@ -49,10 +49,9 @@ NEWLN = b"\n"
DEFAULT_CHUNK_SIZE = 255 * 1024
def _create_property(field_name, docstring,
read_only=False, closed_only=False):
"""Helper for creating properties to read/write to files.
"""
def _grid_in_property(field_name, docstring, read_only=False,
closed_only=False):
"""Create a GridIn property."""
def getter(self):
if closed_only and not self._closed:
raise AttributeError("can only get %r on a closed file" %
@ -70,7 +69,7 @@ def _create_property(field_name, docstring,
self._file[field_name] = value
if read_only:
docstring = docstring + "\n\nThis attribute is read-only."
docstring += "\n\nThis attribute is read-only."
elif closed_only:
docstring = "%s\n\n%s" % (docstring, "This attribute is read-only and "
"can only be read after :meth:`close` "
@ -81,6 +80,20 @@ def _create_property(field_name, docstring,
return property(getter, doc=docstring)
def _grid_out_property(field_name, docstring):
"""Create a GridOut property."""
def getter(self):
self._ensure_file()
# Protect against PHP-237
if field_name == 'length':
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
docstring += "\n\nThis attribute is read-only."
return property(getter, doc=docstring)
class GridIn(object):
"""Class to write data to GridFS.
"""
@ -177,19 +190,19 @@ class GridIn(object):
"""
return self._closed
_id = _create_property("_id", "The ``'_id'`` value for this file.",
_id = _grid_in_property("_id", "The ``'_id'`` value for this file.",
read_only=True)
filename = _create_property("filename", "Name of this file.")
name = _create_property("filename", "Alias for `filename`.")
content_type = _create_property("contentType", "Mime-type for this file.")
length = _create_property("length", "Length (in bytes) of this file.",
filename = _grid_in_property("filename", "Name of this file.")
name = _grid_in_property("filename", "Alias for `filename`.")
content_type = _grid_in_property("contentType", "Mime-type for this file.")
length = _grid_in_property("length", "Length (in bytes) of this file.",
closed_only=True)
chunk_size = _create_property("chunkSize", "Chunk size for this file.",
chunk_size = _grid_in_property("chunkSize", "Chunk size for this file.",
read_only=True)
upload_date = _create_property("uploadDate",
upload_date = _grid_in_property("uploadDate",
"Date that this file was uploaded.",
closed_only=True)
md5 = _create_property("md5", "MD5 of the contents of this file "
md5 = _grid_in_property("md5", "MD5 of the contents of this file "
"(generated on the server).",
closed_only=True)
@ -370,8 +383,7 @@ class GridIn(object):
class GridOut(object):
"""Class to read data out of GridFS.
"""
def __init__(self, root_collection, file_id=None, file_document=None,
_connect=True):
def __init__(self, root_collection, file_id=None, file_document=None):
"""Read a file from GridFS
Application developers should generally not need to
@ -385,11 +397,13 @@ class GridOut(object):
:Parameters:
- `root_collection`: root collection to read from
- `file_id`: value of ``"_id"`` for the file to read
- `file_document`: file document from `root_collection.files`
- `file_id` (optional): value of ``"_id"`` for the file to read
- `file_document` (optional): file document from
`root_collection.files`
.. versionadded:: 1.9
The `file_document` parameter.
.. versionchanged:: 3.0
Creating a GridOut does not immediately retrieve the file metadata
from the server. Metadata is fetched when first needed.
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an "
@ -401,27 +415,19 @@ class GridOut(object):
self.__buffer = EMPTY
self.__position = 0
self._file = file_document
if _connect:
self._ensure_file()
_id = _create_property("_id", "The ``'_id'`` value for this file.", True)
filename = _create_property("filename", "Name of this file.", True)
name = _create_property("filename", "Alias for `filename`.", True)
content_type = _create_property("contentType", "Mime-type for this file.",
True)
length = _create_property("length", "Length (in bytes) of this file.",
True)
chunk_size = _create_property("chunkSize", "Chunk size for this file.",
True)
upload_date = _create_property("uploadDate",
"Date that this file was first uploaded.",
True)
aliases = _create_property("aliases", "List of aliases for this file.",
True)
metadata = _create_property("metadata", "Metadata attached to this file.",
True)
md5 = _create_property("md5", "MD5 of the contents of this file "
"(generated on the server).", True)
_id = _grid_out_property("_id", "The ``'_id'`` value for this file.")
filename = _grid_out_property("filename", "Name of this file.")
name = _grid_out_property("filename", "Alias for `filename`.")
content_type = _grid_out_property("contentType", "Mime-type for this file.")
length = _grid_out_property("length", "Length (in bytes) of this file.")
chunk_size = _grid_out_property("chunkSize", "Chunk size for this file.")
upload_date = _grid_out_property("uploadDate",
"Date that this file was first uploaded.")
aliases = _grid_out_property("aliases", "List of aliases for this file.")
metadata = _grid_out_property("metadata", "Metadata attached to this file.")
md5 = _grid_out_property("md5", "MD5 of the contents of this file "
"(generated on the server).")
def _ensure_file(self):
if not self._file:

View File

@ -423,7 +423,6 @@ class _Bulk(object):
'only be executed once.')
self.executed = True
client = self.collection.database.connection
client._ensure_connected(sync=True)
write_concern = write_concern or self.collection.write_concern
if self.ordered:
@ -433,7 +432,7 @@ class _Bulk(object):
if write_concern.get('w') == 0:
self.execute_no_results(generator)
elif client.max_wire_version > 1:
elif client._writable_max_wire_version() > 1:
return self.execute_command(generator, write_concern)
else:
return self.execute_legacy(generator, write_concern)

View File

@ -24,6 +24,8 @@ from pymongo.cluster_description import (updated_cluster_description,
ClusterDescription)
from pymongo.errors import AutoReconnect
from pymongo.server import Server
from pymongo.server_selectors import (secondary_server_selector,
writable_server_selector)
class Cluster(object):
@ -49,35 +51,30 @@ class Cluster(object):
with self._lock:
self._ensure_opened()
def select_servers(self, selector, server_wait_time=None):
def select_servers(self, selector,
server_wait_time=common.SERVER_WAIT_TIME):
"""Return a list of Servers matching selector, or time out.
:Parameters:
- `selector`: function that takes a list of Servers and returns
a subset of them.
- `server_wait_time` (optional): maximum seconds to wait. If not
provided, the initial ClusterSettings' value is used.
provided, the default value common.SERVER_WAIT_TIME is used.
Raises exc:`AutoReconnect` after `server_wait_time` if no
matching servers are found.
"""
if server_wait_time is not None:
wait_time = server_wait_time
else:
wait_time = self._settings.server_wait_time
with self._lock:
self._description.check_compatible()
# TODO: use settings.server_wait_time.
# TODO: use monotonic time if available.
now = time.time()
end_time = now + wait_time
end_time = now + server_wait_time
server_descriptions = self._apply_selector(selector)
while not server_descriptions:
# No suitable servers.
if wait_time == 0 or now > end_time:
if server_wait_time == 0 or now > end_time:
# TODO: more error diagnostics. E.g., if state is
# ReplicaSet but every server is Unknown, and the host list
# is non-empty, and doesn't intersect with settings.seeds,
@ -96,13 +93,14 @@ class Cluster(object):
# came after our most recent selector() call, since we've
# held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.time()
server_descriptions = self._apply_selector(selector)
return [self.get_server_by_address(sd.address)
for sd in server_descriptions]
def select_server(self, selector, server_wait_time=None):
def select_server(self, selector, server_wait_time=common.SERVER_WAIT_TIME):
"""Like select_servers, but choose a random server if several match."""
return random.choice(self.select_servers(selector, server_wait_time))
@ -130,6 +128,33 @@ class Cluster(object):
def has_server(self, address):
return address in self._servers
def get_primary(self):
"""Return primary's address or None."""
# Implemented here in Cluster instead of MongoClient, so it can lock.
with self._lock:
cluster_type = self._description.cluster_type
if cluster_type != CLUSTER_TYPE.ReplicaSetWithPrimary:
return None
description = writable_server_selector(
self._description.known_servers)[0]
return description.address
def get_secondaries(self):
"""Return set of secondary addresses."""
# Implemented here in Cluster instead of MongoClient, so it can lock.
with self._lock:
cluster_type = self._description.cluster_type
if cluster_type not in (CLUSTER_TYPE.ReplicaSetWithPrimary,
CLUSTER_TYPE.ReplicaSetNoPrimary):
return []
descriptions = secondary_server_selector(
self._description.known_servers)
return set([d.address for d in descriptions])
def close(self):
# TODO.
raise NotImplementedError
@ -146,8 +171,22 @@ class Cluster(object):
if server:
server.pool.reset()
def reset_server(self, address):
with self._lock:
server = self._servers.get(address)
if server:
server.pool.reset()
# Mark this server Unknown.
self._description = self._description.reset_server(address)
self._update_servers()
def reset(self):
"""Reset all pools and rediscover all servers."""
"""Reset all pools and disconnect from all servers.
The cluster reconnects on demand, or after common.HEARTBEAT_FREQUENCY
seconds.
"""
with self._lock:
for server in self._servers.values():
server.pool.reset()
@ -156,8 +195,6 @@ class Cluster(object):
self._description = self._description.reset()
self._update_servers()
self._request_check_all()
@property
def description(self):
return self._description

View File

@ -73,6 +73,33 @@ class ClusterDescription(object):
common.MIN_SUPPORTED_WIRE_VERSION,
common.MAX_SUPPORTED_WIRE_VERSION))
def reset_server(self, address):
"""A copy of this description, with one server marked Unknown."""
sds = self.server_descriptions()
# The default ServerDescription's type is Unknown.
sds[address] = ServerDescription(address)
if self._cluster_type == CLUSTER_TYPE.ReplicaSetWithPrimary:
cluster_type = _check_has_primary(sds)
else:
cluster_type = self._cluster_type
return ClusterDescription(cluster_type, sds, self._set_name)
def reset(self):
"""A copy of this description, with all servers marked Unknown."""
if self._cluster_type == CLUSTER_TYPE.ReplicaSetWithPrimary:
cluster_type = CLUSTER_TYPE.ReplicaSetNoPrimary
else:
cluster_type = self._cluster_type
# The default ServerDescription's type is Unknown.
sds = dict((address, ServerDescription(address))
for address in self._server_descriptions)
return ClusterDescription(cluster_type, sds, self._set_name)
def server_descriptions(self):
"""Dict of (address, ServerDescription)."""
return self._server_descriptions.copy()
@ -165,7 +192,7 @@ def updated_cluster_description(cluster_description, server_description):
cluster_type = _check_has_primary(sds)
elif server_type == SERVER_TYPE.RSPrimary:
cluster_type = _update_rs_from_primary(
cluster_type, set_name = _update_rs_from_primary(
sds, set_name, server_description)
elif server_type in (

View File

@ -374,11 +374,6 @@ class Collection(common.BaseObject):
.. mongodoc:: insert
"""
client = self.database.connection
# Batch inserts require us to know the connected primary's
# max_bson_size, max_message_size, and max_write_batch_size.
# We have to be connected to the primary to know that.
client._ensure_connected(True)
docs = doc_or_docs
return_one = False
if isinstance(docs, collections.MutableMapping):
@ -410,7 +405,7 @@ class Collection(common.BaseObject):
concern = kwargs or self.write_concern
safe = concern.get("w") != 0
if client.max_wire_version > 1 and safe:
if client._writable_max_wire_version() > 1 and safe:
# Insert command
command = SON([('insert', self.name),
('ordered', not continue_on_error)])
@ -539,10 +534,6 @@ class Collection(common.BaseObject):
if not isinstance(upsert, bool):
raise TypeError("upsert must be an instance of bool")
client = self.database.connection
# Need to connect to know the wire version, and may want to connect
# before applying SON manipulators.
client._ensure_connected(True)
if manipulate:
document = self.__database._fix_incoming(document, self)
@ -559,7 +550,8 @@ class Collection(common.BaseObject):
if first.startswith('$'):
check_keys = False
if client.max_wire_version > 1 and safe:
client = self.database.connection
if client._writable_max_wire_version() > 1 and safe:
# Update command
command = SON([('update', self.name)])
if concern:
@ -686,10 +678,7 @@ class Collection(common.BaseObject):
safe = concern.get("w") != 0
client = self.database.connection
# Need to connect to know the wire version.
client._ensure_connected(True)
if client.max_wire_version > 1 and safe:
if client._writable_max_wire_version() > 1 and safe:
# Delete command
command = SON([('delete', self.name)])
if concern:

View File

@ -92,7 +92,7 @@ class CommandCursor(object):
client = self.__collection.database.connection
try:
response = client._send_message_with_response(
msg, _connection_to_use=self.__conn_id)
msg, address=self.__conn_id)
except AutoReconnect:
# Don't try to send kill cursors on another socket
# or to another server. It can cause a _pinValue

View File

@ -31,7 +31,7 @@ from bson.py3compat import string_type, integer_types
# Defaults until we connect to a server and get updated limits.
MAX_BSON_SIZE = 16 * (1024 ** 2)
MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE # TODO: remove.
MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE
MIN_WIRE_VERSION = 0
MAX_WIRE_VERSION = 0
MAX_WRITE_BATCH_SIZE = 1000
@ -43,6 +43,12 @@ MAX_SUPPORTED_WIRE_VERSION = 2
# Frequency to call ismaster on servers, in seconds.
HEARTBEAT_FREQUENCY = 10
# How long to wait, in seconds, for a suitable server to be found before
# aborting an operation. For example, if the client attempts an insert
# during a replica set election, SERVER_WAIT_TIME governs the longest it
# is willing to wait for a new primary to be found.
SERVER_WAIT_TIME = 5
# Spec requires at least 10ms between ismaster calls.
MIN_HEARTBEAT_INTERVAL = 0.01
@ -279,7 +285,6 @@ VALIDATORS = {
'readpreferencetags': validate_read_preference_tags,
'latencythresholdms': validate_positive_float,
'secondaryacceptablelatencyms': validate_positive_float,
'auto_start_request': validate_boolean,
'authmechanism': validate_auth_mechanism,
'authsource': validate_string,
'gssapiservicename': validate_string,

View File

@ -864,7 +864,7 @@ class Cursor(object):
"exhaust": self.__exhaust,
}
if self.__connection_id is not None:
kwargs["_connection_to_use"] = self.__connection_id
kwargs["address"] = self.__connection_id
try:
response = client._send_message_with_response(message, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -1,127 +0,0 @@
# Copyright 2009-2014 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""TODO: help string."""
import random
import threading
from bson.py3compat import (string_type)
from pymongo import (database,
monitor,
pool,
uri_parser, ReadPreference)
from pymongo.cluster import Cluster
from pymongo.cluster_description import CLUSTER_TYPE
from pymongo.errors import (ConfigurationError)
from pymongo.server_selectors import (writable_server_selector,
secondary_server_selector)
from pymongo.settings import ClusterSettings
class MongoClientNew(object):
"""Connection to one or more MongoDB servers.
"""
def __init__(
self,
host='localhost',
port=27017,
replicaSet=None,
):
"""TODO: docstring"""
if isinstance(host, string_type):
host = [host]
if not isinstance(port, int):
raise TypeError("port must be an instance of int")
seeds = set()
for entity in host:
seeds.update(uri_parser.split_hosts(entity, port))
if not seeds:
raise ConfigurationError("need to specify at least one host")
cluster_settings = ClusterSettings(
seeds=seeds,
set_name=replicaSet,
pool_class=pool.Pool,
monitor_class=monitor.Monitor,
condition_class=threading.Condition)
# TODO: parse URI, socket timeouts, ssl args, auth,
# pool_class, document_class, pool options, condition_class,
# default database.
self._cluster = Cluster(cluster_settings)
self._cluster.open()
# TODO: these are here to fake the old MongoClient's API for the sake
# of existing Database, Collection, and Cursor.
self.read_preference = ReadPreference.PRIMARY
self.uuid_subtype = 4
self.write_concern = {}
self.document_class = dict
self.tz_aware = False
@property
def is_mongos(self):
return self._cluster.description.cluster_type == CLUSTER_TYPE.Mongos
# TODO: Remove. Database, Collection, etc. should use Cluster.
def _send_message_with_response(
self,
msg,
read_preference=ReadPreference.PRIMARY,
exhaust=False):
"""Send a message to MongoDB and return a Response object.
:Parameters:
- `msg`: (request_id, data) pair making up the message to send.
- `read_preference`: A ReadPreference.
- `exhaust`: True for an exhaust cursor's initial query.
"""
request_id, data, max_doc_size = msg
# TODO: real read preferences.
if read_preference == ReadPreference.PRIMARY:
servers = self._cluster.select_servers(writable_server_selector)
else:
servers = self._cluster.select_servers(secondary_server_selector)
server = random.choice(servers)
return server.send_message_with_response(data, request_id, exhaust)
def __getattr__(self, name):
"""Get a database by name.
Raises :class:`~pymongo.errors.InvalidName` if an invalid
database name is used.
:Parameters:
- `name`: the name of the database to get
"""
return database.Database(self, name)
def __getitem__(self, name):
"""Get a database by name.
Raises :class:`~pymongo.errors.InvalidName` if an invalid
database name is used.
:Parameters:
- `name`: the name of the database to get
"""
return self.__getattr__(name)

View File

@ -79,7 +79,7 @@ class Monitor(threading.Thread):
self.close()
else:
start = time.time() # TODO: monotonic.
self._event.wait(self._settings.heartbeat_frequency)
self._event.wait(common.HEARTBEAT_FREQUENCY)
self._event.clear()
wait_time = time.time() - start
if wait_time < common.MIN_HEARTBEAT_INTERVAL:

View File

@ -20,9 +20,15 @@ import time
import threading
import weakref
from pymongo import thread_util
from bson.py3compat import u
from pymongo import helpers, message, thread_util
from pymongo.errors import ConnectionFailure
# If the first getaddrinfo call of this interpreter's life is on a thread,
# while the main thread holds the import lock, getaddrinfo deadlocks trying
# to import the IDNA codec. Import it here, where presumably we're on the
# main thread, to avoid the deadlock. See PYTHON-607.
u('foo').encode('idna')
try:
from ssl import match_hostname, CertificateError
@ -152,8 +158,24 @@ class SocketInfo(object):
# Are we being used by an exhaust cursor?
self._exhaust = False
def command(self, dbname, spec):
"""Execute a command over the socket, or raise socket.error.
:Parameters:
- `dbname`: name of the database on which to run the command
- `spec`: a command document as a dict, SON, or mapping object
"""
# TODO: command should already be encoded.
request_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec)
self.send_message(msg)
response = self.receive_message(1, request_id)
unpacked = helpers._unpack_response(response)['data'][0]
msg = "command %r failed: %%s" % spec
helpers._check_command_response(unpacked, None, msg)
return unpacked
def send_message(self, message):
"""Send a raw BSON message.
"""Send a raw BSON message or raise socket.error.
If a network exception is raised, the socket is closed.
"""
@ -164,7 +186,7 @@ class SocketInfo(object):
raise
def receive_message(self, operation, request_id):
"""Receive a raw BSON message.
"""Receive a raw BSON message or raise socket.error.
If any exception is raised, the socket is closed.
"""

View File

@ -14,6 +14,9 @@
"""Communicate with one MongoDB server in a cluster."""
import socket
from pymongo.errors import AutoReconnect, DocumentTooLarge
from pymongo.server_type import SERVER_TYPE
from pymongo.response import Response, ExhaustResponse
@ -38,29 +41,61 @@ class Server(object):
"""Check the server's state soon."""
self._monitor.request_check()
def send_message_with_response(self, message, request_id, exhaust):
"""Send a message to MongoDB and return a Response object.
def send_message(self, message):
"""Send an unacknowledged message to MongoDB.
Can raise ConnectionFailure.
:Parameters:
- `message`: BSON bytes.
- `message`: (request_id, data).
- `request_id`: A number.
- `exhaust`: If True, the socket used stays checked out. It is
returned along with its Pool in the Response.
"""
with self._pool.get_socket() as sock_info:
sock_info.exhaust(exhaust)
sock_info.send_message(message)
response_data = sock_info.receive_message(1, request_id)
if exhaust:
return ExhaustResponse(
data=response_data,
address=self._description.address,
socket_info=sock_info,
pool=self._pool)
else:
return Response(
data=response_data,
address=self._description.address)
request_id, data = self._check_bson_size(message)
try:
with self._pool.get_socket() as sock_info:
sock_info.send_message(data)
except socket.error as exc:
raise AutoReconnect(str(exc))
def send_message_with_response(self, message, exhaust=False):
"""Send a message to MongoDB and return a Response object.
Can raise ConnectionFailure.
:Parameters:
- `message`: (request_id, data, max_doc_size) or (request_id, data).
- `request_id`: A number.
- `exhaust` (optional): If True, the socket used stays checked out.
It is returned along with its Pool in the Response.
"""
request_id, data = self._check_bson_size(message)
try:
with self._pool.get_socket() as sock_info:
sock_info.exhaust(exhaust)
sock_info.send_message(data)
response_data = sock_info.receive_message(1, request_id)
if exhaust:
return ExhaustResponse(
data=response_data,
address=self._description.address,
socket_info=sock_info,
pool=self._pool)
else:
return Response(
data=response_data,
address=self._description.address)
except socket.error as exc:
raise AutoReconnect(str(exc))
def start_request(self):
# TODO: Remove implicit threadlocal requests, use explicit requests.
self.pool.start_request()
def in_request(self):
return self.pool.in_request()
def end_request(self):
self.pool.end_request()
@property
def description(self):
@ -75,6 +110,27 @@ class Server(object):
def pool(self):
return self._pool
def _check_bson_size(self, message):
"""Make sure the message doesn't include BSON documents larger
than the server will accept.
:Parameters:
- `message`: (request_id, data, max_doc_size) or (request_id, data)
Returns request_id, data.
"""
if len(message) == 3:
request_id, data, max_doc_size = message
if max_doc_size > self.description.max_bson_size:
raise DocumentTooLarge(
"BSON document too large (%d bytes) - the connected server"
"supports BSON document sizes up to %d bytes." %
(max_doc_size, self.description.max_bson_size))
return request_id, data
else:
# get_more and kill_cursors messages don't include BSON documents.
return message
def __str__(self):
d = self._description
return '<Server "%s:%s" %s>' % (

View File

@ -27,12 +27,10 @@ class ClusterSettings(object):
self,
seeds=None,
set_name=None,
server_wait_time=None,
pool_class=None,
pool_options=None,
monitor_class=monitor.Monitor,
condition_class=threading.Condition,
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
):
"""Represent MongoClient's configuration.
@ -40,12 +38,10 @@ class ClusterSettings(object):
"""
self._seeds = seeds or [('localhost', 27017)]
self._set_name = set_name
self._server_wait_time = server_wait_time or 5 # Seconds.
self._pool_class = pool_class or pool.Pool
self._pool_options = pool_options or PoolOptions()
self._monitor_class = monitor_class or monitor.Monitor
self._condition_class = condition_class or threading.Condition
self._heartbeat_frequency = heartbeat_frequency
self._direct = (len(self._seeds) == 1 and not set_name)
@property
@ -57,10 +53,6 @@ class ClusterSettings(object):
def set_name(self):
return self._set_name
@property
def server_wait_time(self):
return self._server_wait_time
@property
def pool_class(self):
return self._pool_class
@ -77,10 +69,6 @@ class ClusterSettings(object):
def condition_class(self):
return self._condition_class
@property
def heartbeat_frequency(self):
return self._heartbeat_frequency
@property
def direct(self):
"""Connect directly to a single server, or use a set of servers?

View File

@ -118,36 +118,6 @@ class Counter(object):
return self._counters.get(self.ident.get(), 0)
class Future(object):
"""Minimal backport of concurrent.futures.Future.
event_class makes this Future adaptable for async clients.
"""
def __init__(self, event_class):
self._event = event_class()
self._result = None
self._exception = None
def set_result(self, result):
self._result = result
self._event.set()
def set_exception(self, exc):
if hasattr(exc, 'with_traceback'):
# Python 3: avoid potential reference cycle.
self._exception = exc.with_traceback(None)
else:
self._exception = exc
self._event.set()
def result(self):
self._event.wait()
if self._exception:
raise self._exception
else:
return self._result
### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire
class Semaphore:

View File

@ -30,6 +30,7 @@ from functools import wraps
import pymongo
from bson.py3compat import _unicode
from pymongo import common
from test.version import Version
# hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode
@ -64,7 +65,9 @@ class ClientContext(object):
self.test_commands_enabled = False
self.is_mongos = False
try:
self.client = pymongo.MongoClient(host, port)
client = pymongo.MongoClient(host, port)
client.admin.command('ismaster') # Can we connect?
self.client = client
except pymongo.errors.ConnectionFailure:
self.client = None
else:
@ -211,6 +214,58 @@ class IntegrationTest(unittest.TestCase):
pass
class client_knobs(object):
def __init__(self, heartbeat_frequency=None, server_wait_time=None):
self.heartbeat_frequency = heartbeat_frequency
self.server_wait_time = server_wait_time
self.old_heartbeat_frequency = None
self.old_server_wait_time = None
def enable(self):
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
self.old_server_wait_time = common.SERVER_WAIT_TIME
if self.heartbeat_frequency is not None:
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
if self.server_wait_time is not None:
common.SERVER_WAIT_TIME = self.server_wait_time
def __enter__(self):
self.enable()
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.SERVER_WAIT_TIME = self.old_server_wait_time
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
class MockClientTest(unittest.TestCase):
"""Base class for TestCases that use MockClient.
This class is *not* an IntegrationTest: if properly written, MockClient
tests do not require a running server.
The class temporarily overrides HEARTBEAT_FREQUENCY and SERVER_WAIT_TIME
to speed up tests.
"""
def setUp(self):
super(MockClientTest, self).setUp()
self.client_knobs = client_knobs(
heartbeat_frequency=0.001,
server_wait_time=0.1)
self.client_knobs.enable()
def tearDown(self):
self.client_knobs.disable()
super(MockClientTest, self).tearDown()
def connection_string(seeds=[pair]):
if client_context.auth_enabled:
return "mongodb://%s:%s@%s" % (db_user, db_pwd, ','.join(seeds))

View File

@ -15,13 +15,16 @@
"""Tools for mocking parts of PyMongo to test other parts."""
import socket
from functools import partial
from pymongo import common
from pymongo import MongoClient, MongoReplicaSetClient
from pymongo import MongoClient
from pymongo.ismaster import IsMaster
from pymongo.monitor import Monitor
from pymongo.pool import Pool, PoolOptions
from pymongo.server_description import ServerDescription
from test import host as default_host, port as default_port
from test.utils import my_partial
class MockPool(Pool):
@ -51,9 +54,45 @@ class MockPool(Pool):
return sock_info
class MockClientBase(object):
def __init__(self, standalones, members, mongoses, config):
"""standalones, etc., are like ['a:1', 'b:2']"""
class MockMonitor(Monitor):
def __init__(
self,
client,
server_description,
cluster,
pool,
cluster_settings):
# MockMonitor gets a 'client' arg, regular monitors don't.
self.client = client
self.mock_address = server_description.address
# Actually connect to the default server.
Monitor.__init__(
self,
ServerDescription((default_host, default_port)),
cluster,
pool,
cluster_settings)
def _check_once(self):
try:
response = self.client.mock_is_master('%s:%d' % self.mock_address)
return ServerDescription(self.mock_address, IsMaster(response))
except socket.error:
return None
class MockClient(MongoClient):
def __init__(
self, standalones, members, mongoses, ismaster_hosts=None,
*args, **kwargs):
"""A MongoClient connected to the default server, with a mock cluster.
standalones, members, mongoses determine the configuration of the
cluster. They are formatted like ['a:1', 'b:2']. ismaster_hosts
provides an alternative host list for the server's mocked ismaster
response; see test_connect_with_internal_ips.
"""
self.mock_standalones = standalones[:]
self.mock_members = members[:]
@ -62,8 +101,8 @@ class MockClientBase(object):
else:
self.mock_primary = None
if config is not None:
self.mock_ismaster_hosts = config
if ismaster_hosts is not None:
self.mock_ismaster_hosts = ismaster_hosts
else:
self.mock_ismaster_hosts = members[:]
@ -78,6 +117,11 @@ class MockClientBase(object):
# Hostname -> max write batch size
self.mock_max_write_batch_sizes = {}
kwargs['_pool_class'] = partial(MockPool, self)
kwargs['_monitor_class'] = partial(MockMonitor, self)
super(MockClient, self).__init__(*args, **kwargs)
def kill_host(self, host):
"""Host is like 'a:1'."""
self.mock_down_hosts.append(host)
@ -106,6 +150,7 @@ class MockClientBase(object):
if host in self.mock_standalones:
return {
'ok': 1,
'ismaster': True,
'minWireVersion': min_wire_version,
'maxWireVersion': max_wire_version,
@ -116,6 +161,7 @@ class MockClientBase(object):
# Simulate a replica set member.
response = {
'ok': 1,
'ismaster': ismaster,
'secondary': not ismaster,
'setName': 'rs',
@ -131,6 +177,7 @@ class MockClientBase(object):
if host in self.mock_mongoses:
return {
'ok': 1,
'ismaster': True,
'minWireVersion': min_wire_version,
'maxWireVersion': max_wire_version,
@ -140,49 +187,3 @@ class MockClientBase(object):
# In test_internal_ips(), we try to connect to a host listed
# in ismaster['hosts'] but not publicly accessible.
raise socket.error('Unknown host: %s' % host)
def simple_command(self, sock_info, dbname, spec):
# __simple_command is also used for authentication, but in this
# test it's only used for ismaster.
assert spec == {'ismaster': 1}
response = self.mock_is_master(
'%s:%s' % (sock_info.mock_host, sock_info.mock_port))
ping_time = 10
return response, ping_time
class MockClient(MockClientBase, MongoClient):
def __init__(
self, standalones, members, mongoses, ismaster_hosts=None,
*args, **kwargs
):
MockClientBase.__init__(
self, standalones, members, mongoses, ismaster_hosts)
kwargs['_pool_class'] = my_partial(MockPool, self)
MongoClient.__init__(self, *args, **kwargs)
def _MongoClient__simple_command(self, sock_info, dbname, spec):
return self.simple_command(sock_info, dbname, spec)
class MockReplicaSetClient(MockClientBase, MongoReplicaSetClient):
def __init__(
self, standalones, members, mongoses, ismaster_hosts=None,
*args, **kwargs
):
MockClientBase.__init__(
self, standalones, members, mongoses, ismaster_hosts)
kwargs['_pool_class'] = my_partial(MockPool, self)
MongoReplicaSetClient.__init__(self, *args, **kwargs)
def _MongoReplicaSetClient__is_master(self, host):
response = self.mock_is_master('%s:%s' % host)
connection_pool = MockPool(self, host)
ping_time = 10
return response, connection_pool, ping_time
def _MongoReplicaSetClient__simple_command(self, sock_info, dbname, spec):
return self.simple_command(sock_info, dbname, spec)

View File

@ -244,7 +244,7 @@ class TestBinary(unittest.TestCase):
def test_uri_to_uuid(self):
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
client = MongoClient(uri, _connect=False)
client = MongoClient(uri, connect=False)
self.assertEqual(client.pymongo_test.test.uuid_subtype, CSHARP_LEGACY)
@client_context.require_connection

View File

@ -14,12 +14,14 @@
"""Test the mongo_client module."""
import contextlib
import datetime
import os
import threading
import socket
import sys
import time
import warnings
sys.path[0:0] = [""]
@ -34,8 +36,11 @@ from pymongo.errors import (AutoReconnect,
ConfigurationError,
ConnectionFailure,
InvalidName,
OperationFailure)
OperationFailure,
CursorNotFound)
from pymongo.server_selectors import writable_server_selector
from test import (client_context,
client_knobs,
connection_string,
host,
pair,
@ -44,7 +49,8 @@ from test import (client_context,
unittest,
IntegrationTest,
db_pwd,
db_user)
db_user,
MockClientTest)
from test.pymongo_mocks import MockClient
from test.utils import (assertRaisesExactly,
delay,
@ -54,9 +60,10 @@ from test.utils import (assertRaisesExactly,
TestRequestMixin,
_TestLazyConnectMixin,
lazy_client_trial,
NTHREADS,
get_pool,
one)
one,
connected,
wait_until)
class ClientUnitTest(unittest.TestCase, TestRequestMixin):
@ -64,7 +71,7 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin):
@classmethod
def setUpClass(cls):
cls.client = MongoClient(host, port, _connect=False)
cls.client = MongoClient(host, port, connect=False)
def test_types(self):
self.assertRaises(TypeError, MongoClient, 1)
@ -97,18 +104,18 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin):
self.assertRaises(TypeError, iterate)
def test_get_default_database(self):
c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False)
c = MongoClient("mongodb://%s:%d/foo" % (host, port), connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_default_database_error(self):
# URI with no database.
c = MongoClient("mongodb://%s:%d/" % (host, port), _connect=False)
c = MongoClient("mongodb://%s:%d/" % (host, port), connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (host, port)
c = MongoClient(uri, _connect=False)
c = MongoClient(uri, connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
@ -120,19 +127,22 @@ class TestClient(IntegrationTest, TestRequestMixin):
cls.client = client_context.client
def test_constants(self):
MongoClient.HOST = host
MongoClient.PORT = port
self.assertTrue(MongoClient())
# Set bad defaults.
MongoClient.HOST = "somedomainthatdoesntexist.org"
MongoClient.PORT = 123456789
assertRaisesExactly(
ConnectionFailure, MongoClient, connectTimeoutMS=600)
self.assertTrue(MongoClient(host, port))
with client_knobs(server_wait_time=0.01):
with self.assertRaises(AutoReconnect):
connected(MongoClient())
# Override the defaults. No error.
connected(MongoClient(host, port))
# Set good defaults.
MongoClient.HOST = host
MongoClient.PORT = port
self.assertTrue(MongoClient())
# No error.
connected(MongoClient())
def assertIsInstance(self, obj, cls, msg=None):
"""Backport from Python 2.7."""
@ -141,13 +151,12 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.fail(self._formatMessage(msg, standardMsg))
def test_init_disconnected(self):
c = MongoClient(host, port, _connect=False)
c = MongoClient(host, port, connect=False)
self.assertIsInstance(c.is_primary, bool)
self.assertIsInstance(c.is_mongos, bool)
self.assertIsInstance(c.max_pool_size, int)
self.assertIsInstance(c.nodes, frozenset)
self.assertIsInstance(c.auto_start_request, bool)
self.assertEqual(dict, c.get_document_class())
self.assertIsInstance(c.tz_aware, bool)
self.assertIsInstance(c.max_bson_size, int)
@ -168,36 +177,29 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertTrue(c.min_wire_version >= 0)
bad_host = "somedomainthatdoesntexist.org"
c = MongoClient(bad_host, port, connectTimeoutMS=1, _connect=False)
with client_knobs(server_wait_time=0.01):
c = MongoClient(bad_host, port)
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
def test_init_disconnected_with_auth(self):
uri = "mongodb://user:pass@somedomainthatdoesntexist"
c = MongoClient(uri, connectTimeoutMS=1, _connect=False)
with client_knobs(server_wait_time=0.01):
c = MongoClient(uri)
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
def test_connect(self):
# Check that the exception is a ConnectionFailure, not a subclass like
# AutoReconnect
assertRaisesExactly(
ConnectionFailure, MongoClient,
"somedomainthatdoesntexist.org", connectTimeoutMS=600)
assertRaisesExactly(
ConnectionFailure, MongoClient, host, 123456789)
self.assertTrue(MongoClient(host, port))
def test_equality(self):
c = connected(MongoClient(host, port))
# ClientContext.client is constructed as MongoClient(host, port)
self.assertEqual(self.client, MongoClient(host, port))
self.assertEqual(self.client, c)
# Explicitly test inequality
self.assertFalse(self.client != MongoClient(host, port))
self.assertFalse(self.client != c)
def test_host_w_port(self):
self.assertTrue(MongoClient("%s:%d" % (host, port)))
assertRaisesExactly(
ConnectionFailure, MongoClient, "%s:1234567" % (host,), port)
with client_knobs(server_wait_time=0.01):
with self.assertRaises(AutoReconnect):
connected(MongoClient("%s:1234567"))
def test_repr(self):
# Making host a str avoids the 'u' prefix in Python 2, so the repr is
@ -326,8 +328,9 @@ class TestClient(IntegrationTest, TestRequestMixin):
coll.count()
def test_from_uri(self):
self.assertEqual(self.client,
MongoClient("mongodb://%s:%d" % (host, port)))
self.assertEqual(
self.client,
connected(MongoClient("mongodb://%s:%d" % (host, port))))
@client_context.require_auth
def test_auth_from_uri(self):
@ -356,12 +359,12 @@ class TestClient(IntegrationTest, TestRequestMixin):
# Auth with lazy connection.
MongoClient(
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port),
_connect=False).pymongo_test.test.find_one()
connect=False).pymongo_test.test.find_one()
# Wrong password.
bad_client = MongoClient(
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port),
_connect=False)
connect=False)
self.assertRaises(OperationFailure,
bad_client.pymongo_test.test.find_one)
@ -375,7 +378,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
def test_lazy_auth_raises_operation_failure(self):
lazy_client = MongoClient(
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port),
_connect=False)
connect=False)
assertRaisesExactly(
OperationFailure, lazy_client.test.collection.find_one)
@ -392,7 +395,8 @@ class TestClient(IntegrationTest, TestRequestMixin):
if not os.access(mongodb_socket, os.R_OK):
raise SkipTest("Socket file is not accessible")
self.assertTrue(MongoClient("mongodb://%s" % mongodb_socket))
# No error.
connected(MongoClient("mongodb://%s" % mongodb_socket))
client = MongoClient("mongodb://%s" % mongodb_socket)
client.pymongo_test.test.save({"dummy": "object"})
@ -402,8 +406,9 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertTrue("pymongo_test" in dbs)
# Confirm it fails with a missing socket
self.assertRaises(ConnectionFailure, MongoClient,
"mongodb:///tmp/none-existent.sock")
self.assertRaises(
ConnectionFailure,
connected, MongoClient("mongodb:///tmp/non-existent.sock"))
def test_fork(self):
# Test using a client before and after a fork.
@ -473,15 +478,6 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertTrue(isinstance(db.test.find_one(), dict))
self.assertFalse(isinstance(db.test.find_one(), SON))
c.document_class = SON
try:
self.assertEqual(SON, c.document_class)
self.assertTrue(isinstance(db.test.find_one(), SON))
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
finally:
c.document_class = dict
c = get_client(pair, document_class=SON)
db = c.pymongo_test
@ -489,11 +485,13 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertTrue(isinstance(db.test.find_one(), SON))
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
c.document_class = dict
# document_class is read-only in PyMongo 3.0.
with self.assertRaises(AttributeError):
c.document_class = dict
self.assertEqual(dict, c.document_class)
self.assertTrue(isinstance(db.test.find_one(), dict))
self.assertFalse(isinstance(db.test.find_one(), SON))
with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
self.assertRaises(DeprecationWarning, c.get_document_class)
def test_timeouts(self):
client = MongoClient(host, port, connectTimeoutMS=10500)
@ -503,10 +501,10 @@ class TestClient(IntegrationTest, TestRequestMixin):
def test_socket_timeout_ms_validation(self):
c = get_client(pair, socketTimeoutMS=10 * 1000)
self.assertEqual(10, c._MongoClient__pool_opts.socket_timeout)
self.assertEqual(10, get_pool(c).opts.socket_timeout)
c = get_client(pair, socketTimeoutMS=None)
self.assertEqual(None, c._MongoClient__pool_opts.socket_timeout)
c = connected(get_client(pair, socketTimeoutMS=None))
self.assertEqual(None, get_pool(c).opts.socket_timeout)
self.assertRaises(ConfigurationError,
get_client, pair, socketTimeoutMS=0)
@ -564,24 +562,20 @@ class TestClient(IntegrationTest, TestRequestMixin):
naive.pymongo_test.test.find_one()["x"])
def test_ipv6(self):
try:
MongoClient("[::1]")
except:
# Either mongod was started without --ipv6
# or the OS doesn't support it (or both).
raise SkipTest("No IPv6")
# Try a few simple things
MongoClient("mongodb://[::1]:%d" % (port,))
MongoClient("mongodb://[::1]:%d/?w=0" % (port,))
MongoClient("[::1]:%d,localhost:%d" % (port, port))
with client_knobs(server_wait_time=0.01):
try:
connected(MongoClient("[::1]"))
except:
# Either mongod was started without --ipv6
# or the OS doesn't support it (or both).
raise SkipTest("No IPv6")
if client_context.auth_enabled:
auth_str = "%s:%s@" % (db_user, db_pwd)
else:
auth_str = ""
uri = "mongodb://%slocalhost:%d,[::1]:%d" % (auth_str, port, port)
uri = "mongodb://%s[::1]:%d" % (auth_str, port)
client = MongoClient(uri)
client.pymongo_test.test.save({"dummy": u("object")})
client.pymongo_test_bernie.test.save({"dummy": u("object")})
@ -616,9 +610,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertFalse(locked)
def test_contextlib(self):
import contextlib
client = get_client(pair, auto_start_request=False)
client = get_client(pair)
client.pymongo_test.drop_collection("test")
client.pymongo_test.test.insert({"foo": "bar"})
@ -628,11 +620,12 @@ class TestClient(IntegrationTest, TestRequestMixin):
with contextlib.closing(client):
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
self.assertEqual(None, client._MongoClient__member)
self.assertEqual(1, len(get_pool(client).sockets))
self.assertEqual(0, len(get_pool(client).sockets))
with self.client as client:
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
self.assertEqual(None, client._MongoClient__member)
self.assertEqual(0, len(get_pool(client).sockets))
def test_with_start_request(self):
pool = get_pool(self.client)
@ -667,37 +660,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
self.assertNoRequest(pool)
self.assertDifferentSock(pool)
def test_auto_start_request(self):
for bad_horrible_value in (None, 5, 'hi!'):
self.assertRaises(
(TypeError, ConfigurationError),
lambda: get_client(pair, auto_start_request=bad_horrible_value)
)
# auto_start_request should default to False
self.assertFalse(self.client.auto_start_request)
client = get_client(pair, auto_start_request=True)
self.assertTrue(client.auto_start_request)
# Assure we acquire a request socket.
client.pymongo_test.test.find_one()
self.assertTrue(client.in_request())
pool = get_pool(client)
self.assertRequestSocket(pool)
self.assertSameSock(pool)
client.end_request()
self.assertNoRequest(pool)
self.assertDifferentSock(pool)
# Trigger auto_start_request
client.pymongo_test.test.find_one()
self.assertRequestSocket(pool)
self.assertSameSock(pool)
def test_nested_request(self):
# auto_start_request is False
pool = get_pool(self.client)
self.assertFalse(self.client.in_request())
@ -830,7 +793,8 @@ class TestClient(IntegrationTest, TestRequestMixin):
def test_operation_failure_with_request(self):
# Ensure MongoClient doesn't close socket after it gets an error
# response to getLastError. PYTHON-395.
c = get_client(pair, auto_start_request=True)
c = get_client(pair)
c.start_request()
pool = get_pool(c)
# Pool reserves a socket for this thread.
@ -850,98 +814,42 @@ class TestClient(IntegrationTest, TestRequestMixin):
def test_alive(self):
self.assertTrue(self.client.alive())
client = MongoClient('doesnt exist', _connect=False)
client = MongoClient('doesnt exist', connect=False)
self.assertFalse(client.alive())
def test_wire_version(self):
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs',
_connect=False)
def test_kill_cursors(self):
self.collection = self.client.pymongo_test.test
self.collection.remove()
# Ensure two batches.
self.collection.insert({'_id': i} for i in range(200))
c.set_wire_version_range('a:1', 1, 5)
c.db.command('ismaster') # Connect.
self.assertEqual(c.min_wire_version, 1)
self.assertEqual(c.max_wire_version, 5)
c.set_wire_version_range('a:1', 10, 11)
c.disconnect()
self.assertRaises(ConfigurationError, c.db.collection.find_one)
def test_max_wire_version(self):
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs',
_connect=False)
c.set_max_write_batch_size('a:1', 1)
c.set_max_write_batch_size('b:2', 2)
# Starts with default max batch size.
self.assertEqual(1000, c.max_write_batch_size)
c.db.command('ismaster') # Connect.
# Uses primary's max batch size.
self.assertEqual(c.max_write_batch_size, 1)
# b becomes primary.
c.mock_primary = 'b:2'
c.disconnect()
self.assertEqual(1000, c.max_write_batch_size)
c.db.command('ismaster') # Connect.
self.assertEqual(c.max_write_batch_size, 2)
def test_wire_version_mongos_ha(self):
c = MockClient(
standalones=[],
members=[],
mongoses=['a:1', 'b:2', 'c:3'],
host='a:1,b:2,c:3',
_connect=False)
c.set_wire_version_range('a:1', 2, 5)
c.set_wire_version_range('b:2', 2, 2)
c.set_wire_version_range('c:3', 1, 1)
c.db.command('ismaster') # Connect.
# Which member did we use?
used_host = '%s:%s' % (c.host, c.port)
expected_min, expected_max = c.mock_wire_versions[used_host]
self.assertEqual(expected_min, c.min_wire_version)
self.assertEqual(expected_max, c.max_wire_version)
c.set_wire_version_range('a:1', 0, 0)
c.set_wire_version_range('b:2', 0, 0)
c.set_wire_version_range('c:3', 0, 0)
c.disconnect()
c.db.command('ismaster')
used_host = '%s:%s' % (c.host, c.port)
expected_min, expected_max = c.mock_wire_versions[used_host]
self.assertEqual(expected_min, c.min_wire_version)
self.assertEqual(expected_max, c.max_wire_version)
cursor = self.collection.find()
next(cursor)
self.client.kill_cursors([cursor.cursor_id])
with self.assertRaises(CursorNotFound):
list(cursor)
@client_context.require_replica_set
def test_replica_set(self):
name = client_context.setname
MongoClient(host, port, replicaSet=name) # No error.
connected(MongoClient(host, port, replicaSet=name)) # No error.
self.assertRaises(
ConfigurationError,
MongoClient, host, port, replicaSet='bad' + name)
with client_knobs(server_wait_time=0.01):
client = MongoClient(host, port, replicaSet='bad' + name)
with self.assertRaises(AutoReconnect):
connected(client)
def test_lazy_connect_w0(self):
client = get_client(connection_string(), _connect=False)
client = get_client(connection_string(), connect=False)
client.pymongo_test.test.insert({}, w=0)
client = get_client(connection_string(), _connect=False)
client = get_client(connection_string(), connect=False)
client.pymongo_test.test.update({}, {'$set': {'x': 1}}, w=0)
client = get_client(connection_string(), _connect=False)
client = get_client(connection_string(), connect=False)
client.pymongo_test.test.remove(w=0)
@client_context.require_no_mongos
@ -953,6 +861,9 @@ class TestClient(IntegrationTest, TestRequestMixin):
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
# Ensure a socket.
connected(client)
# Cause a network error.
sock_info = one(pool.sockets)
sock_info.sock.close()
@ -993,6 +904,90 @@ class TestClient(IntegrationTest, TestRequestMixin):
c.test.collection.find_one()
class TestClientProperties(MockClientTest):
@client_context.require_connection
def test_wire_version(self):
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs',
connect=False)
c.set_wire_version_range('a:1', 1, 5)
c._get_cluster().select_servers(writable_server_selector) # Connect.
self.assertEqual(c.min_wire_version, 1)
self.assertEqual(c.max_wire_version, 5)
c.set_wire_version_range('a:1', 10, 11)
c.disconnect()
c._get_cluster()
self.assertRaises(ConfigurationError, c.db.collection.find_one)
def test_max_wire_version(self):
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs',
connect=False)
c.set_max_write_batch_size('a:1', 1)
c.set_max_write_batch_size('b:2', 2)
# Starts with default max batch size.
self.assertEqual(1000, c.max_write_batch_size)
c._get_cluster()
wait_until(lambda: len(c.nodes) == 3, 'connect')
# Uses primary's max batch size.
self.assertEqual(c.max_write_batch_size, 1)
# b becomes primary.
c.mock_primary = 'b:2'
c.disconnect()
self.assertEqual(1000, c.max_write_batch_size)
c._get_cluster()
wait_until(lambda: len(c.nodes) == 3, 'connect')
self.assertEqual(c.max_write_batch_size, 2)
def test_wire_version_mongos_ha(self):
# TODO: Reimplement Mongos HA with PyMongo 3's MongoClient.
raise SkipTest('Mongos HA must be reimplemented in PyMongo 3')
c = MockClient(
standalones=[],
members=[],
mongoses=['a:1', 'b:2', 'c:3'],
host='a:1,b:2,c:3',
connect=False)
c.set_wire_version_range('a:1', 2, 5)
c.set_wire_version_range('b:2', 2, 2)
c.set_wire_version_range('c:3', 1, 1)
c.db.command('ismaster') # Connect.
# Which member did we use?
used_host = '%s:%s' % (c.host, c.port)
expected_min, expected_max = c.mock_wire_versions[used_host]
self.assertEqual(expected_min, c.min_wire_version)
self.assertEqual(expected_max, c.max_wire_version)
c.set_wire_version_range('a:1', 0, 0)
c.set_wire_version_range('b:2', 0, 0)
c.set_wire_version_range('c:3', 0, 0)
c.disconnect()
c.db.command('ismaster')
used_host = '%s:%s' % (c.host, c.port)
expected_min, expected_max = c.mock_wire_versions[used_host]
self.assertEqual(expected_min, c.min_wire_version)
self.assertEqual(expected_max, c.max_wire_version)
class TestClientLazyConnect(IntegrationTest, _TestLazyConnectMixin):
def _get_client(self, **kwargs):
@ -1019,63 +1014,40 @@ class TestClientLazyConnectBadSeeds(IntegrationTest):
client = collection.database.connection
self.assertEqual(0, len(client.nodes))
lazy_client_trial(reset, connect, test, self._get_client)
with client_knobs(server_wait_time=0.01):
lazy_client_trial(reset, connect, test, self._get_client)
class TestClientLazyConnectOneGoodSeed(
IntegrationTest,
_TestLazyConnectMixin):
def _get_client(self, **kwargs):
kwargs.setdefault('connectTimeoutMS', 100)
# Assume there are no open mongods listening on a.com, b.com, ....
bad_seeds = ['%s.com' % chr(ord('a') + i) for i in range(10)]
seeds = bad_seeds + [pair]
# MongoClient puts the seeds in a set before iterating, so order is
# undefined.
return get_client(connection_string(seeds=seeds), **kwargs)
def test_insert(self):
def reset(collection):
collection.drop()
def insert(collection, dummy):
collection.insert({})
def test(collection):
self.assertEqual(NTHREADS, collection.count())
lazy_client_trial(reset, insert, test, self._get_client)
class TestMongoClientFailover(IntegrationTest):
class TestMongoClientFailover(MockClientTest):
def test_discover_primary(self):
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs')
# Disable background refresh.
with client_knobs(heartbeat_frequency=9999999):
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs')
self.assertEqual('a', c.host)
self.assertEqual(1, c.port)
self.assertEqual(3, len(c.nodes))
wait_until(lambda: len(c.nodes) == 3, 'connect')
self.assertEqual('a', c.host)
self.assertEqual(1, c.port)
# Fail over.
c.kill_host('a:1')
c.mock_primary = 'b:2'
# Fail over.
c.kill_host('a:1')
c.mock_primary = 'b:2'
# Force reconnect.
c.disconnect()
c.db.command('ismaster')
self.assertEqual('b', c.host)
self.assertEqual(2, c.port)
c.disconnect()
self.assertEqual(0, len(c.nodes))
# a:1 is still in nodes.
self.assertEqual(3, len(c.nodes))
c._get_cluster().select_servers(writable_server_selector)
self.assertEqual('b', c.host)
self.assertEqual(2, c.port)
# a:1 not longer in nodes.
self.assertLess(len(c.nodes), 3)
wait_until(lambda: len(c.nodes) == 2, 'discover node "c"')
def test_reconnect(self):
# Verify the node list isn't forgotten during a network failure.
@ -1086,6 +1058,8 @@ class TestMongoClientFailover(IntegrationTest):
host='b:2', # Pass a secondary.
replicaSet='rs')
wait_until(lambda: len(c.nodes) == 3, 'connect')
# Total failure.
c.kill_host('a:1')
c.kill_host('b:2')
@ -1094,12 +1068,11 @@ class TestMongoClientFailover(IntegrationTest):
# MongoClient discovers it's alone.
self.assertRaises(AutoReconnect, c.db.collection.find_one)
# But it remembers its node list.
self.assertEqual(3, len(c.nodes))
# So it can reconnect.
# But it can reconnect.
c.revive_host('a:1')
c.db.command('ismaster')
c._get_cluster().select_servers(writable_server_selector)
self.assertEqual('a', c.host)
self.assertEqual(1, c.port)
if __name__ == "__main__":

View File

@ -1,72 +0,0 @@
# Copyright 2009-2014 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the mongo_client module."""
import sys
sys.path[0:0] = [""]
from pymongo import ReadPreference
from pymongo.mongo_client_new import MongoClientNew
from test import host, port, unittest, IntegrationTest, client_context, SkipTest
from test.utils import one
class TestClientNew(IntegrationTest):
def test_buildinfo(self):
c = MongoClientNew(host, port)
assert 'version' in c.admin.command('buildinfo')
def test_ismaster(self):
c = MongoClientNew(['%s:%d' % (host, port)])
response = c.admin.command('ismaster')
self.assertTrue('ismaster' in response)
@client_context.require_replica_set
def test_ismaster_primary(self):
primary_pair = client_context.rs_client.primary
c = MongoClientNew(['%s:%d' % primary_pair],
replicaSet=client_context.setname)
response = c.admin.command('ismaster')
self.assertTrue(response['ismaster'])
@client_context.require_replica_set
def test_ismaster_secondary(self):
secondary_pair = one(client_context.rs_client.secondaries)
c = MongoClientNew(['%s:%d' % secondary_pair],
replicaSet=client_context.setname)
# 'ismaster' does not obey read preference, so hack it.
secondary_response = c.admin['$cmd'].find_one(
{'ismaster': 1},
read_preference=ReadPreference.SECONDARY)
# Make sure we really executed ismaster on a secondary.
self.assertTrue(secondary_response.get('secondary'))
def test_find(self):
# TODO: re-enable once MongoClientNew can do auth.
if client_context.auth_enabled:
raise SkipTest("MongoClientNew can't yet do auth")
client_context.client.pymongo_test.test_collection.insert({'_id': 1})
c = MongoClientNew(host, port)
docs = list(c.pymongo_test.test_collection.find())
self.assertEqual([{'_id': 1}], docs)
if __name__ == "__main__":
unittest.main()

View File

@ -27,8 +27,7 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.cluster import Cluster
from pymongo.cluster_description import CLUSTER_TYPE
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
InvalidOperation)
ConnectionFailure)
from pymongo.ismaster import IsMaster
from pymongo.monitor import Monitor
from pymongo.read_preferences import MovingAverage
@ -36,7 +35,7 @@ from pymongo.server_description import ServerDescription
from pymongo.server_selectors import (any_server_selector,
writable_server_selector)
from pymongo.settings import ClusterSettings
from test import unittest
from test import unittest, client_knobs
class MockSocketInfo(object):
@ -95,8 +94,7 @@ def create_mock_cluster(seeds=None, set_name=None, monitor_class=MockMonitor):
partitioned_seeds,
set_name=set_name,
pool_class=MockPool,
monitor_class=monitor_class,
heartbeat_frequency=99999999)
monitor_class=monitor_class)
c = Cluster(cluster_settings)
c.open()
@ -122,7 +120,20 @@ def get_type(cluster, hostname):
return description.server_type
class TestSingleServerCluster(unittest.TestCase):
class ClusterTest(unittest.TestCase):
"""Disables periodic monitoring, to make tests deterministic."""
def setUp(self):
super(ClusterTest, self).setUp()
self.client_knobs = client_knobs(heartbeat_frequency=9999999)
self.client_knobs.enable()
def tearDown(self):
self.client_knobs.disable()
super(ClusterTest, self).tearDown()
class TestSingleServerCluster(ClusterTest):
def test_direct_connection(self):
for server_type, ismaster_response in [
(SERVER_TYPE.RSPrimary, {
@ -206,7 +217,7 @@ class TestSingleServerCluster(unittest.TestCase):
self.assertEqual(2, s.description.round_trip_time)
class TestMultiServerCluster(unittest.TestCase):
class TestMultiServerCluster(ClusterTest):
def test_unexpected_host(self):
# Received ismaster response from host not in cluster.
# E.g., a race where the host is removed before it responds.
@ -390,6 +401,45 @@ class TestMultiServerCluster(unittest.TestCase):
self.assertEqual(CLUSTER_TYPE.ReplicaSetWithPrimary,
c.description.cluster_type)
def test_reset_server(self):
c = create_mock_cluster(set_name='rs')
got_ismaster(c, ('a', 27017), {
'ok': 1,
'ismaster': True,
'setName': 'rs',
'hosts': ['a', 'b']})
got_ismaster(c, ('b', 27017), {
'ok': 1,
'ismaster': False,
'secondary': True,
'setName': 'rs',
'hosts': ['a', 'b']})
c.reset_server(('a', 27017))
self.assertEqual(SERVER_TYPE.Unknown, get_type(c, 'a'))
self.assertEqual(SERVER_TYPE.RSSecondary, get_type(c, 'b'))
self.assertEqual('rs', c.description.set_name)
self.assertEqual(CLUSTER_TYPE.ReplicaSetNoPrimary,
c.description.cluster_type)
got_ismaster(c, ('a', 27017), {
'ok': 1,
'ismaster': True,
'setName': 'rs',
'hosts': ['a', 'b']})
self.assertEqual(SERVER_TYPE.RSPrimary, get_type(c, 'a'))
self.assertEqual(CLUSTER_TYPE.ReplicaSetWithPrimary,
c.description.cluster_type)
c.reset_server(('b', 27017))
self.assertEqual(SERVER_TYPE.RSPrimary, get_type(c, 'a'))
self.assertEqual(SERVER_TYPE.Unknown, get_type(c, 'b'))
self.assertEqual('rs', c.description.set_name)
self.assertEqual(CLUSTER_TYPE.ReplicaSetWithPrimary,
c.description.cluster_type)
def test_discover_set_name_from_primary(self):
# Discovering a replica set without the setName supplied by the user
# is not yet supported by MongoClient, but Cluster can do it.
@ -620,7 +670,7 @@ class TestMultiServerCluster(unittest.TestCase):
self.assertEqual(2, write_batch_size())
class TestClusterErrors(unittest.TestCase):
class TestClusterErrors(ClusterTest):
# Errors when calling ismaster.
def test_pool_reset(self):

View File

@ -58,7 +58,7 @@ class TestCollectionNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
client = MongoClient(host, port, connect=False)
cls.db = client.pymongo_test
def test_collection(self):
@ -1372,7 +1372,6 @@ class TestCollection(IntegrationTest):
def test_aggregation_cursor(self):
db = self.db
if client_context.setname:
db = client_context.rs_client[db.name]
# Test that getMore messages are sent to the right server.
db.read_preference = ReadPreference.SECONDARY
@ -1399,7 +1398,6 @@ class TestCollection(IntegrationTest):
db = self.db
db.drop_collection("test")
if client_context.setname:
db = client_context.rs_client[db.name]
# Test that getMore messages are sent to the right server.
db.read_preference = ReadPreference.SECONDARY
@ -1717,7 +1715,6 @@ class TestCollection(IntegrationTest):
client = get_client(pair, max_pool_size=1)
socks = get_pool(client).sockets
self.assertEqual(1, len(socks))
# Make sure the socket is returned after exhaustion.
cur = client[self.db.name].test.find(exhaust=True)

View File

@ -24,10 +24,9 @@ from bson.code import Code
from bson.objectid import ObjectId
from bson.son import SON
from pymongo.mongo_client import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.errors import ConfigurationError, OperationFailure
from test import client_context, pair, unittest
from test.utils import get_client, get_rs_client
from test.utils import get_client, connected
@client_context.require_connection
@ -232,33 +231,11 @@ class TestCommon(unittest.TestCase):
self.assertTrue(coll.insert(doc))
# Equality tests
self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,)))
self.assertFalse(m != MongoClient("mongodb://%s/?w=0" % (pair,)))
self.assertEqual(m,
connected(MongoClient("mongodb://%s/?w=0" % (pair,))))
@client_context.require_replica_set
def test_mongo_replica_set_client(self):
setname = client_context.setname
m = get_rs_client(pair, replicaSet=setname, w=0)
coll = m.pymongo_test.write_concern_test
coll.drop()
doc = {"_id": ObjectId()}
coll.insert(doc)
self.assertTrue(coll.insert(doc, w=0))
self.assertTrue(coll.insert(doc))
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = client_context.rs_client
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc, w=0))
self.assertRaises(OperationFailure, coll.insert, doc)
self.assertRaises(OperationFailure, coll.insert, doc, w=1)
m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s" % (pair, setname))
coll = m.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert, doc)
m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname))
coll = m.pymongo_test.write_concern_test
self.assertTrue(coll.insert(doc))
self.assertFalse(m !=
connected(MongoClient("mongodb://%s/?w=0" % (pair,))))
if __name__ == "__main__":

View File

@ -46,7 +46,7 @@ class TestCursorNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
client = MongoClient(host, port, connect=False)
cls.db = client.test
def test_deepcopy_cursor_littered_with_regexes(self):

View File

@ -64,7 +64,7 @@ class TestDatabaseNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.client = MongoClient(host, port, _connect=False)
cls.client = MongoClient(host, port, connect=False)
def test_name(self):
self.assertRaises(TypeError, Database, self.client, 4)

View File

@ -15,17 +15,21 @@
"""Test the errors module."""
import sys
sys.path[0:0] = [""]
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from test import unittest
from test import unittest, client_knobs
from test.utils import connected
class TestErrors(unittest.TestCase):
def test_base_exception(self):
self.assertRaises(PyMongoError, MongoClient, port=0)
# connected(MongoClient(...)) with a bad port raises AutoReconnect.
with client_knobs(server_wait_time=0.01):
self.assertRaises(PyMongoError, connected, MongoClient(port=0))
if __name__ == '__main__':

View File

@ -35,14 +35,20 @@ from gridfs.errors import (NoFile,
UnsupportedAPI)
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from test import client_context, qcheck, unittest, host, port, IntegrationTest
from test import (client_context,
client_knobs,
IntegrationTest,
host,
port,
unittest,
qcheck)
class TestGridFileNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
client = MongoClient(host, port, connect=False)
cls.db = client.pymongo_test
def test_grid_file(self):
@ -222,7 +228,9 @@ class TestGridFile(IntegrationTest):
def test_grid_out_default_opts(self):
self.assertRaises(TypeError, GridOut, "foo")
self.assertRaises(NoFile, GridOut, self.db.fs, 5)
gout = GridOut(self.db.fs, 5)
with self.assertRaises(NoFile):
gout.name
a = GridIn(self.db.fs)
a.close()
@ -280,7 +288,9 @@ class TestGridFile(IntegrationTest):
three = GridOut(self.db.fs, 5, file_document=self.db.fs.files.find_one())
self.assertEqual(b"foo bar", three.read())
self.assertRaises(NoFile, GridOut, self.db.fs, file_document={})
four = GridOut(self.db.fs, file_document={})
with self.assertRaises(NoFile):
four.name
def test_write_file_like(self):
one = GridIn(self.db.fs)
@ -588,23 +598,24 @@ Bye"""))
def test_grid_out_lazy_connect(self):
fs = self.db.fs
outfile = GridOut(fs, file_id=-1, _connect=False)
outfile = GridOut(fs, file_id=-1)
self.assertRaises(NoFile, outfile.read)
self.assertRaises(NoFile, getattr, outfile, 'filename')
infile = GridIn(fs, filename=1)
infile.close()
outfile = GridOut(fs, infile._id, _connect=False)
outfile = GridOut(fs, infile._id)
outfile.read()
outfile.filename
def test_grid_in_lazy_connect(self):
client = MongoClient('badhost', _connect=False)
fs = client.db.fs
infile = GridIn(fs, file_id=-1, chunk_size=1)
self.assertRaises(ConnectionFailure, infile.write, b'data goes here')
self.assertRaises(ConnectionFailure, infile.close)
with client_knobs(server_wait_time=0.01):
client = MongoClient('badhost', connect=False)
fs = client.db.fs
infile = GridIn(fs, file_id=-1, chunk_size=1)
self.assertRaises(ConnectionFailure, infile.write, b'data')
self.assertRaises(ConnectionFailure, infile.close)
if __name__ == "__main__":

View File

@ -33,6 +33,7 @@ from bson.py3compat import u, StringIO, string_type
from gridfs.errors import (FileExists,
NoFile)
from test import (client_context,
client_knobs,
connection_string,
unittest,
host,
@ -77,7 +78,7 @@ class TestGridfsNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
client = MongoClient(host, port, connect=False)
cls.db = client.pymongo_test
def test_gridfs(self):
@ -146,6 +147,24 @@ class TestGridfs(IntegrationTest):
self.assertEqual(255 * 1024, raw["chunkSize"])
self.assertTrue(isinstance(raw["md5"], string_type))
def test_delete_ensures_index(self):
chunks = self.db.fs.chunks
# setUp has dropped collections.
self.assertFalse(chunks.index_information())
self.fs.delete(file_id=1)
# delete() has ensured an index on (files_id, n).
# index_information() is like:
# {
# '_id_': {'key': [('_id', 1)]},
# 'files_id_1_n_1': {'key': [('files_id', 1), ('n', 1)]}
# }
self.assertTrue(any(
info.get('key') == [('files_id', 1), ('n', 1)]
for info in chunks.index_information().values()))
def test_alt_collection(self):
oid = self.alt.put(b"hello world")
self.assertEqual(b"hello world", self.alt.get(oid).read())
@ -375,13 +394,15 @@ class TestGridfs(IntegrationTest):
self.assertFalse(self.db.connection.in_request())
def test_gridfs_lazy_connect(self):
client = MongoClient('badhost', _connect=False)
db = client.db
self.assertRaises(ConnectionFailure, gridfs.GridFS, db)
with client_knobs(server_wait_time=0.01):
client = MongoClient('badhost', connect=False)
db = client.db
gfs = gridfs.GridFS(db)
self.assertRaises(ConnectionFailure, gfs.list)
fs = gridfs.GridFS(db, _connect=False)
f = fs.new_file() # Still no connection.
self.assertRaises(ConnectionFailure, f.close)
fs = gridfs.GridFS(db)
f = fs.new_file() # Still no connection.
self.assertRaises(ConnectionFailure, f.close)
def test_gridfs_find(self):
self.fs.put(b"test2", filename="two")
@ -448,10 +469,10 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase):
client = MongoClient(
connection_string(seeds=[secondary_pair]),
read_preference=ReadPreference.SECONDARY,
_connect=False)
connect=False)
# Still no connection.
fs = gridfs.GridFS(client.test_gridfs_secondary_lazy, _connect=False)
fs = gridfs.GridFS(client.test_gridfs_secondary_lazy)
# Connects, doesn't create index.
self.assertRaises(NoFile, fs.get_last_version)

View File

@ -20,8 +20,9 @@ import threading
sys.path[0:0] = [""]
from pymongo.errors import AutoReconnect
from test import unittest, client_context
from test import unittest, client_context, SkipTest, MockClientTest
from test.pymongo_mocks import MockClient
from test.utils import connected, wait_until
@client_context.require_connection
@ -53,18 +54,22 @@ def do_simple_op(client, nthreads):
assert t.passed
class TestMongosHA(unittest.TestCase):
def mock_client(self, connect):
class TestMongosHA(MockClientTest):
def mock_client(self):
return MockClient(
standalones=[],
members=[],
mongoses=['a:1', 'b:2', 'c:3'],
host='a:1,b:2,c:3',
_connect=connect)
connect=False)
def test_lazy_connect(self):
# TODO: Reimplement Mongos HA with PyMongo 3's MongoClient.
raise SkipTest('Mongos HA must be reimplemented with 3.0 MongoClient')
nthreads = 10
client = self.mock_client(False)
client = self.mock_client()
self.assertEqual(0, len(client.nodes))
# Trigger initial connection.
@ -73,15 +78,24 @@ class TestMongosHA(unittest.TestCase):
def test_reconnect(self):
nthreads = 10
client = self.mock_client(True)
self.assertEqual(3, len(client.nodes))
client = connected(self.mock_client())
# connected() ensures we've contacted at least one mongos. Wait for
# all of them.
wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses')
# Trigger reconnect.
client.disconnect()
do_simple_op(client, nthreads)
self.assertEqual(3, len(client.nodes))
wait_until(lambda: len(client.nodes) == 3,
'reconnect to all mongoses')
def test_failover(self):
# TODO: PyMongo 3's MongoClient currently picks a new Mongos at random
# for each operation (besides getMore). Need to "pin".
raise SkipTest('Mongos HA must be reimplemented with 3.0 MongoClient')
nthreads = 1
# ['1:1', '2:2', '3:3', ...]

View File

@ -24,6 +24,7 @@ import time
from pymongo import MongoClient
from pymongo.errors import ConfigurationError, ConnectionFailure, \
ExceededMaxWaiters
from pymongo.server_selectors import writable_server_selector
sys.path[0:0] = [""]
@ -231,9 +232,16 @@ class _TestPoolingBase(unittest.TestCase):
return Pool(pair, PoolOptions(*args, **kwargs))
def assert_no_request(self):
self.assertTrue(
self.c._MongoClient__member is None or
NO_REQUEST == get_pool(self.c)._get_request_state())
try:
server = self.c._cluster.select_server(
writable_server_selector,
server_wait_time=0)
self.assertEqual(NO_REQUEST, server.pool._get_request_state())
except ConnectionFailure:
# Success: we're asserting that we're not in a request, but there's
# no pool at all so the assertion is true.
pass
def assert_request_without_socket(self):
self.assertEqual(NO_SOCKET_YET, get_pool(self.c)._get_request_state())
@ -244,9 +252,16 @@ class _TestPoolingBase(unittest.TestCase):
def assert_pool_size(self, pool_size):
if pool_size == 0:
self.assertTrue(
self.c._MongoClient__member is None
or not get_pool(self.c).sockets)
try:
server = self.c._cluster.select_server(
writable_server_selector,
server_wait_time=0)
self.assertEqual(0, len(server.pool.sockets))
except ConnectionFailure:
# Success: we're asserting that pool size is 0, and there's no
# pool at all so the assertion is true.
pass
else:
self.assertEqual(pool_size, len(get_pool(self.c).sockets))
@ -634,7 +649,6 @@ class TestPooling(_TestPoolingBase):
def loop(pipe):
c = get_client()
self.assertEqual(1, len(get_pool(c).sockets))
c.pymongo_test.test.find_one()
self.assertEqual(1, len(get_pool(c).sockets))
pipe.send(one(get_pool(c).sockets).sock.getsockname())

View File

@ -41,6 +41,7 @@ from test import (client_context,
IntegrationTest,
db_user,
db_pwd)
from test.utils import connected
from test.version import Version
@ -551,8 +552,8 @@ class TestMongosConnection(IntegrationTest):
None, [{}]
):
# Create a client e.g. with read_preference=NEAREST
c = get_client(host, port,
read_preference=mode(tag_sets=tag_sets))
c = connected(get_client(
host, port, read_preference=mode(tag_sets=tag_sets)))
self.assertEqual(is_mongos, c.is_mongos)
cursor = c.pymongo_test.test.find()
@ -591,8 +592,8 @@ class TestMongosConnection(IntegrationTest):
[{'dc': 'la'}, {'dc': 'sf'}],
[{'dc': 'la'}, {'dc': 'sf'}, {}],
):
c = get_client(host, port,
read_preference=mode(tag_sets=tag_sets))
c = connected(get_client(
host, port, read_preference=mode(tag_sets=tag_sets)))
self.assertEqual(is_mongos, c.is_mongos)
cursor = c.pymongo_test.test.find()

View File

@ -47,13 +47,14 @@ from test import (client_context,
SkipTest,
unittest,
db_pwd,
db_user)
from test.pymongo_mocks import MockReplicaSetClient
db_user,
MockClientTest)
from test.pymongo_mocks import MockClient
from test.utils import (
delay, assertReadFrom, assertReadFromAll, read_from_which_host,
remove_all_users, assertRaisesExactly, TestRequestMixin, one,
server_started_with_auth, pools_from_rs_client, get_pool,
get_rs_client, _TestLazyConnectMixin)
get_rs_client, _TestLazyConnectMixin, connected, wait_until)
from test.version import Version
@ -77,6 +78,7 @@ class TestReplicaSetClientBase(unittest.TestCase):
@classmethod
@client_context.require_replica_set
def setUpClass(cls):
raise SkipTest('Replica set tests must be updated for 3.0 MongoClient')
cls.name = client_context.setname
ismaster = client_context.ismaster
cls.w = client_context.w
@ -120,7 +122,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.fail(self._formatMessage(msg, standardMsg))
def test_init_disconnected(self):
c = self._get_client(_connect=False)
c = self._get_client(connect=False)
self.assertIsInstance(c.is_mongos, bool)
self.assertIsInstance(c.max_pool_size, int)
@ -146,28 +148,28 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertEqual(c.max_wire_version, 0)
self.assertTrue(c.min_wire_version >= 0)
c = self._get_client(_connect=False)
c = self._get_client(connect=False)
c.pymongo_test.test.update({}, {}) # Auto-connect for write.
self.assertTrue(c.primary)
c = self._get_client(_connect=False)
c = self._get_client(connect=False)
c.pymongo_test.test.insert({}) # Auto-connect for write.
self.assertTrue(c.primary)
c = self._get_client(_connect=False)
c = self._get_client(connect=False)
c.pymongo_test.test.remove({}) # Auto-connect for write.
self.assertTrue(c.primary)
c = MongoReplicaSetClient(
"somedomainthatdoesntexist.org", replicaSet="rs",
connectTimeoutMS=1, _connect=False)
connectTimeoutMS=1, connect=False)
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
def test_init_disconnected_with_auth_failure(self):
c = MongoReplicaSetClient(
"mongodb://user:pass@somedomainthatdoesntexist", replicaSet="rs",
connectTimeoutMS=1, _connect=False)
connectTimeoutMS=1, connect=False)
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
@ -183,14 +185,14 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
uri = "mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s" % (
host[0], host[1], self.name)
authenticated_client = MongoReplicaSetClient(uri, _connect=False)
authenticated_client = MongoReplicaSetClient(uri, connect=False)
authenticated_client.pymongo_test.test.find_one()
# Wrong password.
bad_uri = "mongodb://user:wrong@%s:%d/pymongo_test?replicaSet=%s" % (
host[0], host[1], self.name)
bad_client = MongoReplicaSetClient(bad_uri, _connect=False)
bad_client = MongoReplicaSetClient(bad_uri, connect=False)
self.assertRaises(
OperationFailure, bad_client.pymongo_test.test.find_one)
@ -310,7 +312,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
lazy_client = MongoReplicaSetClient(
"mongodb://user:wrong@%s/pymongo_test" % pair,
replicaSet=self.name,
_connect=False)
connect=False)
assertRaisesExactly(
OperationFailure, lazy_client.test.collection.find_one)
@ -453,7 +455,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
uri = "mongodb://%s:%d/foo?replicaSet=%s" % (
host[0], host[1], self.name)
c = MongoReplicaSetClient(uri, _connect=False)
c = MongoReplicaSetClient(uri, connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_default_database_error(self):
@ -462,7 +464,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
uri = "mongodb://%s:%d/?replicaSet=%s" % (
host[0], host[1], self.name)
c = MongoReplicaSetClient(uri, _connect=False)
c = MongoReplicaSetClient(uri, connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
@ -471,7 +473,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
uri = "mongodb://%s:%d/foo?replicaSet=%s&authSource=src" % (
host[0], host[1], self.name)
c = MongoReplicaSetClient(uri, _connect=False)
c = MongoReplicaSetClient(uri, connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_iteration(self):
@ -1062,7 +1064,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
self.assertTrue(client.alive())
client = MongoReplicaSetClient(
'doesnt exist', replicaSet='rs', _connect=False)
'doesnt exist', replicaSet='rs', connect=False)
self.assertFalse(client.alive())
@ -1115,17 +1117,17 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin):
c.test.collection.find_one()
class TestReplicaSetWireVersion(unittest.TestCase):
class TestReplicaSetWireVersion(MockClientTest):
@client_context.require_connection
def test_wire_version(self):
c = MockReplicaSetClient(
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='a:1',
replicaSet='rs',
_connect=False)
connect=False)
c.set_wire_version_range('a:1', 1, 5)
c.set_wire_version_range('b:2', 0, 1)
@ -1135,16 +1137,22 @@ class TestReplicaSetWireVersion(unittest.TestCase):
self.assertEqual(c.max_wire_version, 5)
c.set_wire_version_range('a:1', 2, 2)
c.refresh()
self.assertEqual(c.min_wire_version, 2)
wait_until(lambda: c.min_wire_version == 2, 'update min_wire_version')
self.assertEqual(c.max_wire_version, 2)
# A secondary doesn't overlap with us.
c.set_wire_version_range('b:2', 5, 6)
# refresh() raises, as do all following operations.
self.assertRaises(ConfigurationError, c.refresh)
self.assertRaises(ConfigurationError, c.db.collection.find_one)
def raises_configuration_error():
try:
c.db.collection.find_one()
return False
except ConfigurationError:
return True
wait_until(raises_configuration_error,
'notice we are incompatible with server')
self.assertRaises(ConfigurationError, c.db.collection.insert, {})
@ -1159,7 +1167,7 @@ class TestReplicaSetClientLazyConnect(
def test_read_mode_secondary(self):
client = MongoReplicaSetClient(
connection_string(), replicaSet=self.name, _connect=False,
connection_string(), replicaSet=self.name, connect=False,
read_preference=ReadPreference.SECONDARY)
# No error.
@ -1188,34 +1196,35 @@ class TestReplicaSetClientLazyConnectBadSeeds(
return client
class TestReplicaSetClientInternalIPs(unittest.TestCase):
class TestReplicaSetClientInternalIPs(MockClientTest):
@client_context.require_connection
def test_connect_with_internal_ips(self):
# Client is passed an IP it can reach, 'a:1', but the RS config
# only contains unreachable IPs like 'internal-ip'. PYTHON-608.
assertRaisesExactly(
ConnectionFailure,
MockReplicaSetClient,
standalones=[],
members=['a:1'],
mongoses=[],
ismaster_hosts=['internal-ip:27017'],
host='a:1',
replicaSet='rs')
AutoReconnect,
connected,
MockClient(
standalones=[],
members=['a:1'],
mongoses=[],
ismaster_hosts=['internal-ip:27017'],
host='a:1',
replicaSet='rs'))
class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase):
class TestReplicaSetClientMaxWriteBatchSize(MockClientTest):
@client_context.require_connection
def test_max_write_batch_size(self):
c = MockReplicaSetClient(
c = MockClient(
standalones=[],
members=['a:1', 'b:2'],
mongoses=[],
host='a:1',
replicaSet='rs',
_connect=False)
connect=False)
c.set_max_write_batch_size('a:1', 1)
c.set_max_write_batch_size('b:2', 2)
@ -1229,8 +1238,8 @@ class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase):
# b becomes primary.
c.mock_primary = 'b:2'
c.refresh()
self.assertEqual(c.max_write_batch_size, 2)
wait_until(lambda: c.max_write_batch_size == 2,
'update max_write_batch_size')
if __name__ == "__main__":

View File

@ -18,10 +18,11 @@ import sys
sys.path[0:0] = [""]
from pymongo.errors import ConfigurationError, ConnectionFailure
from pymongo.errors import ConnectionFailure, AutoReconnect
from pymongo import ReadPreference
from test import unittest, client_context
from test.pymongo_mocks import MockClient, MockReplicaSetClient
from test import unittest, client_context, client_knobs, MockClientTest
from test.pymongo_mocks import MockClient
from test.utils import wait_until
@client_context.require_connection
@ -29,7 +30,7 @@ def setUpModule():
pass
class TestSecondaryBecomesStandalone(unittest.TestCase):
class TestSecondaryBecomesStandalone(MockClientTest):
# An administrator removes a secondary from a 3-node set and
# brings it back up as standalone, without updating the other
# members' config. Verify we don't continue using it.
@ -42,7 +43,7 @@ class TestSecondaryBecomesStandalone(unittest.TestCase):
replicaSet='rs')
# MongoClient connects to primary by default.
self.assertEqual('a', c.host)
wait_until(lambda: c.host == 'a', 'connect to primary')
self.assertEqual(1, c.port)
# C is brought up as a standalone.
@ -56,79 +57,82 @@ class TestSecondaryBecomesStandalone(unittest.TestCase):
# Force reconnect.
c.disconnect()
try:
with self.assertRaises(AutoReconnect):
c.db.command('ismaster')
except ConfigurationError as e:
self.assertTrue('not a member of replica set' in str(e))
else:
self.fail("MongoClient didn't raise AutoReconnect")
self.assertEqual(None, c.host)
self.assertEqual(None, c.port)
def test_replica_set_client(self):
c = MockReplicaSetClient(
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='a:1,b:2,c:3',
replicaSet='rs')
self.assertTrue(('b', 2) in c.secondaries)
self.assertTrue(('c', 3) in c.secondaries)
wait_until(lambda: ('b', 2) in c.secondaries,
'discover host "b"')
wait_until(lambda: ('c', 3) in c.secondaries,
'discover host "c"')
# C is brought up as a standalone.
c.mock_members.remove('c:3')
c.mock_standalones.append('c:3')
c.refresh()
wait_until(lambda: set([('b', 2)]) == c.secondaries,
'update the list of secondaries')
self.assertEqual(('a', 1), c.primary)
self.assertEqual(set([('b', 2)]), c.secondaries)
class TestSecondaryRemoved(unittest.TestCase):
class TestSecondaryRemoved(MockClientTest):
# An administrator removes a secondary from a 3-node set *without*
# restarting it as standalone.
def test_replica_set_client(self):
c = MockReplicaSetClient(
c = MockClient(
standalones=[],
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='a:1,b:2,c:3',
replicaSet='rs')
self.assertTrue(('b', 2) in c.secondaries)
self.assertTrue(('c', 3) in c.secondaries)
wait_until(lambda: ('b', 2) in c.secondaries, 'discover host "b"')
wait_until(lambda: ('c', 3) in c.secondaries, 'discover host "c"')
# C is removed.
c.mock_ismaster_hosts.remove('c:3')
c.refresh()
wait_until(lambda: set([('b', 2)]) == c.secondaries,
'update list of secondaries')
self.assertEqual(('a', 1), c.primary)
self.assertEqual(set([('b', 2)]), c.secondaries)
class TestSocketError(unittest.TestCase):
class TestSocketError(MockClientTest):
def test_socket_error_marks_member_down(self):
c = MockReplicaSetClient(
standalones=[],
members=['a:1', 'b:2'],
mongoses=[],
host='a:1',
replicaSet='rs')
# Disable background refresh.
with client_knobs(heartbeat_frequency=9999999):
c = MockClient(
standalones=[],
members=['a:1', 'b:2'],
mongoses=[],
host='a:1',
replicaSet='rs')
self.assertEqual(2, len(c._MongoReplicaSetClient__rs_state.members))
wait_until(lambda: len(c.nodes) == 2, 'discover both nodes')
# b now raises socket.error.
c.mock_down_hosts.append('b:2')
self.assertRaises(
ConnectionFailure,
c.db.collection.find_one, read_preference=ReadPreference.SECONDARY)
# b now raises socket.error.
c.mock_down_hosts.append('b:2')
self.assertRaises(
ConnectionFailure,
c.db.collection.find_one,
read_preference=ReadPreference.SECONDARY)
self.assertEqual(1, len(c._MongoReplicaSetClient__rs_state.members))
self.assertEqual(1, len(c.nodes))
class TestSecondaryAdded(unittest.TestCase):
class TestSecondaryAdded(MockClientTest):
def test_client(self):
c = MockClient(
standalones=[],
@ -137,6 +141,8 @@ class TestSecondaryAdded(unittest.TestCase):
host='a:1',
replicaSet='rs')
wait_until(lambda: len(c.nodes) == 2, 'discover both nodes')
# MongoClient connects to primary by default.
self.assertEqual('a', c.host)
self.assertEqual(1, c.port)
@ -151,26 +157,30 @@ class TestSecondaryAdded(unittest.TestCase):
self.assertEqual('a', c.host)
self.assertEqual(1, c.port)
self.assertEqual(set([('a', 1), ('b', 2), ('c', 3)]), c.nodes)
wait_until(lambda: set([('a', 1), ('b', 2), ('c', 3)]) == c.nodes,
'reconnect to both secondaries')
def test_replica_set_client(self):
c = MockReplicaSetClient(
c = MockClient(
standalones=[],
members=['a:1', 'b:2'],
mongoses=[],
host='a:1',
replicaSet='rs')
self.assertEqual(('a', 1), c.primary)
self.assertEqual(set([('b', 2)]), c.secondaries)
wait_until(lambda: ('a', 1) == c.primary, 'discover the primary')
wait_until(lambda: set([('b', 2)]) == c.secondaries,
'discover the secondary')
# C is added.
c.mock_members.append('c:3')
c.mock_ismaster_hosts.append('c:3')
c.refresh()
wait_until(lambda: set([('b', 2), ('c', 3)]) == c.secondaries,
'discover the new secondary')
self.assertEqual(('a', 1), c.primary)
self.assertEqual(set([('b', 2), ('c', 3)]), c.secondaries)
if __name__ == "__main__":

View File

@ -31,7 +31,7 @@ class TestSONManipulator(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(host, port, _connect=False)
client = MongoClient(host, port, connect=False)
cls.db = client.pymongo_test
def test_basic(self):

View File

@ -38,7 +38,7 @@ from pymongo.errors import (ConfigurationError,
OperationFailure)
from pymongo.ssl_support import HAVE_SSL
from test import host, pair, port, SkipTest, unittest
from test.utils import server_started_with_auth, remove_all_users
from test.utils import server_started_with_auth, remove_all_users, connected
from test.version import Version
@ -86,7 +86,7 @@ if HAVE_SSL:
# Is MongoDB configured for SSL?
try:
MongoClient(host, port, connectTimeoutMS=100, ssl=True)
connected(MongoClient(host, port, connectTimeoutMS=100, ssl=True))
SIMPLE_SSL = True
except ConnectionFailure:
pass
@ -94,8 +94,8 @@ if HAVE_SSL:
# Is MongoDB configured with server.pem, ca.pem, and crl.pem from
# mongodb jstests/lib?
try:
ssl_client = MongoClient(host, port, connectTimeoutMS=100, ssl=True,
ssl_certfile=CLIENT_PEM)
ssl_client = connected(MongoClient(host, port, connectTimeoutMS=100,
ssl=True, ssl_certfile=CLIENT_PEM))
CERT_SSL = True
except ConnectionFailure:
pass

View File

@ -140,9 +140,11 @@ class FindPauseFind(RendezvousThread):
def before_rendezvous(self):
# acquire a socket
client = self.collection.database.connection
client.start_request()
list(self.collection.find())
pool = get_pool(self.collection.database.connection)
pool = get_pool(client)
socket_info = pool._get_request_state()
assert isinstance(socket_info, SocketInfo)
self.request_sock = socket_info.sock
@ -241,7 +243,7 @@ class BaseTestThreads(object):
# PYTHON-345, we need to make sure that threads' request sockets are
# closed by disconnect().
#
# 1. Create a client with auto_start_request=True
# 1. Create a client and start a request.
# 2. Start N threads and do a find() in each to get a request socket
# 3. Pause all threads
# 4. In the main thread close all sockets, including threads' request
@ -252,7 +254,8 @@ class BaseTestThreads(object):
#
# If we've fixed PYTHON-345, then only one AutoReconnect is raised,
# and all the threads get new request sockets.
cx = get_client(pair, auto_start_request=True)
cx = get_client(pair)
cx.start_request()
collection = cx.db.pymongo_test
# acquire a request socket for the main thread

View File

@ -19,10 +19,12 @@ import os
import struct
import sys
import threading
import time
from pymongo import MongoClient, MongoReplicaSetClient
from pymongo.errors import AutoReconnect
from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo
from pymongo.server_selectors import writable_server_selector
from test import (client_context,
db_user,
db_pwd)
@ -31,7 +33,7 @@ from test.version import Version
def get_client(*args, **kwargs):
client = MongoClient(*args, **kwargs)
if client_context.auth_enabled and kwargs.get("_connect", True):
if client_context.auth_enabled and kwargs.get("connect", True):
client.admin.authenticate(db_user, db_pwd)
return client
@ -147,6 +149,27 @@ def joinall(threads):
t.join(300)
assert not t.isAlive(), "Thread %s hung" % t
def connected(client):
"""Convenience to wait for a newly-constructed client to connect."""
client.admin.command('ismaster') # Force connection.
return client
def wait_until(predicate, success_description):
"""Wait up to 10 seconds for predicate to be True.
E.g.:
wait_until(lambda: c.primary == ('a', 1),
'connect to the primary')
If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").
"""
start = time.time()
while not predicate():
if time.time() - start > 10:
raise AssertionError("Didn't ever %s" % success_description)
def is_mongos(client):
res = client.admin.command('ismaster')
return res.get('msg', '') == 'isdbgrid'
@ -343,13 +366,9 @@ def assertReadFromAll(testcase, rsc, members, *args, **kwargs):
testcase.assertEqual(members, used)
def get_pool(client):
if isinstance(client, MongoClient):
return client._MongoClient__member.pool
elif isinstance(client, MongoReplicaSetClient):
rs_state = client._MongoReplicaSetClient__rs_state
return rs_state.primary_member.pool
else:
raise TypeError(str(client))
cluster = client._get_cluster()
server = cluster.select_server(writable_server_selector)
return server.pool
def pools_from_rs_client(client):
"""Get Pool instances from a MongoReplicaSetClient.
@ -451,7 +470,7 @@ def lazy_client_trial(reset, target, test, get_client):
try:
for i in range(NTRIALS):
reset(collection)
lazy_client = get_client(_connect=False)
lazy_client = get_client(connect=False)
lazy_collection = lazy_client.pymongo_test.test
run_threads(lazy_collection, target)
test(lazy_collection)
@ -469,7 +488,7 @@ class _TestLazyConnectMixin(object):
Inherit from this class and from unittest.TestCase, and override
_get_client(self, **kwargs), for testing a lazily-connecting
client, i.e. a client initialized with _connect=False.
client, i.e. a client initialized with connect=False.
"""
NTRIALS = 5
NTHREADS = 10
@ -544,7 +563,7 @@ class _TestLazyConnectMixin(object):
def test_max_bson_size(self):
# Client should have sane defaults before connecting, and should update
# its configuration once connected.
c = self._get_client(_connect=False)
c = self._get_client(connect=False)
self.assertEqual(16 * (1024 ** 2), c.max_bson_size)
self.assertEqual(2 * c.max_bson_size, c.max_message_size)