PYTHON-764 SCRAM-SHA-1 automatic upgrade / downgrade.
This commit is contained in:
parent
4b70b2a6f7
commit
ee11436675
@ -36,7 +36,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
|
||||
|
||||
|
||||
MECHANISMS = frozenset(
|
||||
['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN', 'SCRAM-SHA-1'])
|
||||
['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN', 'SCRAM-SHA-1', 'DEFAULT'])
|
||||
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
@ -337,6 +337,13 @@ def _authenticate_mongo_cr(credentials, sock_info):
|
||||
sock_info.command(source, query)
|
||||
|
||||
|
||||
def _authenticate_default(credentials, sock_info):
|
||||
if sock_info.max_wire_version >= 3:
|
||||
return _authenticate_scram_sha1(credentials, sock_info)
|
||||
else:
|
||||
return _authenticate_mongo_cr(credentials, sock_info)
|
||||
|
||||
|
||||
_AUTH_MAP = {
|
||||
'CRAM-MD5': _authenticate_cram_md5,
|
||||
'GSSAPI': _authenticate_gssapi,
|
||||
@ -344,6 +351,7 @@ _AUTH_MAP = {
|
||||
'MONGODB-X509': _authenticate_x509,
|
||||
'PLAIN': _authenticate_plain,
|
||||
'SCRAM-SHA-1': _authenticate_scram_sha1,
|
||||
'DEFAULT': _authenticate_default,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ def _parse_credentials(username, password, database, options):
|
||||
"""Parse authentication credentials."""
|
||||
if username is None:
|
||||
return None
|
||||
mechanism = options.get('authmechanism', 'MONGODB-CR')
|
||||
mechanism = options.get('authmechanism', 'DEFAULT')
|
||||
source = options.get('authsource', database or 'admin')
|
||||
return _build_credentials_tuple(
|
||||
mechanism, source, _unicode(username), _unicode(password), options)
|
||||
|
||||
@ -851,7 +851,7 @@ class Database(common.BaseObject):
|
||||
raise
|
||||
|
||||
def authenticate(self, name, password=None,
|
||||
source=None, mechanism='MONGODB-CR', **kwargs):
|
||||
source=None, mechanism='DEFAULT', **kwargs):
|
||||
"""Authenticate to use this database.
|
||||
|
||||
Authentication lasts for the life of the underlying client
|
||||
@ -883,11 +883,15 @@ class Database(common.BaseObject):
|
||||
specified the current database is used.
|
||||
- `mechanism` (optional): See
|
||||
:data:`~pymongo.auth.MECHANISMS` for options.
|
||||
Defaults to MONGODB-CR (MongoDB Challenge Response protocol)
|
||||
By default, use SCRAM-SHA-1 with MongoDB 2.8 and later,
|
||||
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
|
||||
- `gssapiServiceName` (optional): Used with the GSSAPI mechanism
|
||||
to specify the service name portion of the service principal name.
|
||||
Defaults to 'mongodb'.
|
||||
|
||||
.. versionadded:: 2.8
|
||||
Use SCRAM-SHA-1 with MongoDB 2.8 and later.
|
||||
|
||||
.. versionchanged:: 2.5
|
||||
Added the `source` and `mechanism` parameters. :meth:`authenticate`
|
||||
now raises a subclass of :class:`~pymongo.errors.PyMongoError` if
|
||||
|
||||
@ -336,7 +336,7 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
# get_socket() logs out of the database if logged in with old
|
||||
# credentials, and logs in with new ones.
|
||||
with server.pool.get_socket(all_credentials) as sock_info:
|
||||
with server.get_socket(all_credentials) as sock_info:
|
||||
sock_info.authenticate(credentials)
|
||||
|
||||
# If several threads run _cache_credentials at once, last one wins.
|
||||
@ -652,7 +652,7 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
# Avoid race when other threads log in or out.
|
||||
all_credentials = self.__all_credentials.copy()
|
||||
with server.pool.get_socket(all_credentials) as sock_info:
|
||||
with server.get_socket(all_credentials) as sock_info:
|
||||
return not pool._closed(sock_info.sock)
|
||||
|
||||
except (socket.error, ConnectionFailure):
|
||||
@ -1095,7 +1095,7 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
# Avoid race when other threads log in or out.
|
||||
all_credentials = self.__all_credentials.copy()
|
||||
with server.pool.get_socket(all_credentials) as sock:
|
||||
with server.get_socket(all_credentials) as sock:
|
||||
if username is not None:
|
||||
get_nonce_cmd = SON([("copydbgetnonce", 1),
|
||||
("fromhost", from_host)])
|
||||
|
||||
@ -138,7 +138,7 @@ class Monitor(object):
|
||||
Returns a ServerDescription, or None on error.
|
||||
"""
|
||||
try:
|
||||
with self._pool.get_socket(all_credentials={}) as sock_info:
|
||||
with self._pool.get_socket({}, 0, 0) as sock_info:
|
||||
response, round_trip_time = self._check_with_socket(sock_info)
|
||||
old_rtts = self._server_description.round_trip_times
|
||||
if old_rtts:
|
||||
|
||||
@ -150,6 +150,9 @@ class SocketInfo(object):
|
||||
self.last_checkout = time.time()
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
|
||||
self._min_wire_version = None
|
||||
self._max_wire_version = None
|
||||
|
||||
# The pool's pool_id changes with each reset() so we can close sockets
|
||||
# created before the last reset.
|
||||
self.pool_id = pool.pool_id
|
||||
@ -257,6 +260,20 @@ class SocketInfo(object):
|
||||
self.sock.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
def set_wire_version_range(self, min_wire_version, max_wire_version):
|
||||
self._min_wire_version = min_wire_version
|
||||
self._max_wire_version = max_wire_version
|
||||
|
||||
@property
|
||||
def min_wire_version(self):
|
||||
assert self._min_wire_version is not None
|
||||
return self._min_wire_version
|
||||
|
||||
@property
|
||||
def max_wire_version(self):
|
||||
assert self._max_wire_version is not None
|
||||
return self._max_wire_version
|
||||
|
||||
def exhaust(self, exhaust):
|
||||
self._exhaust = exhaust
|
||||
@ -425,19 +442,25 @@ class Pool:
|
||||
sock.settimeout(self.opts.socket_timeout)
|
||||
return SocketInfo(sock, self, hostname)
|
||||
|
||||
def get_socket(self, all_credentials):
|
||||
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
|
||||
"""Get a socket from the pool.
|
||||
|
||||
Returns a :class:`SocketInfo` object wrapping a connected
|
||||
:class:`socket.socket`, and a bool saying whether the socket was from
|
||||
the pool or freshly created.
|
||||
:class:`socket.socket`.
|
||||
|
||||
The socket is logged in or out as necessary to match ``all_credentials``
|
||||
using the correct authentication mechanism for the server's wire
|
||||
protocol version.
|
||||
|
||||
:Parameters:
|
||||
- `all_credentials`: dict, maps auth source to MongoCredential.
|
||||
- `min_wire_version`: int, minimum protocol the server supports.
|
||||
- `max_wire_version`: int, maximum protocol the server supports.
|
||||
"""
|
||||
# First get a socket, then attempt authentication. Simplifies
|
||||
# semaphore management in the face of network errors during auth.
|
||||
sock_info = self._get_socket_no_auth()
|
||||
sock_info.set_wire_version_range(min_wire_version, max_wire_version)
|
||||
try:
|
||||
sock_info.check_auth(all_credentials)
|
||||
return sock_info
|
||||
|
||||
@ -60,7 +60,7 @@ class Server(object):
|
||||
"""
|
||||
request_id, data = self._check_bson_size(message)
|
||||
try:
|
||||
with self._pool.get_socket(all_credentials) as sock_info:
|
||||
with self.get_socket(all_credentials) as sock_info:
|
||||
sock_info.send_message(data)
|
||||
except socket.error as exc:
|
||||
self._raise_connection_failure(exc)
|
||||
@ -82,7 +82,7 @@ class Server(object):
|
||||
"""
|
||||
request_id, data = self._check_bson_size(message)
|
||||
try:
|
||||
with self._pool.get_socket(all_credentials) as sock_info:
|
||||
with self.get_socket(all_credentials) as sock_info:
|
||||
sock_info.exhaust(exhaust)
|
||||
sock_info.send_message(data)
|
||||
response_data = sock_info.receive_message(1, request_id)
|
||||
@ -99,6 +99,20 @@ class Server(object):
|
||||
except socket.error as exc:
|
||||
self._raise_connection_failure(exc)
|
||||
|
||||
def get_socket(self, all_credentials):
|
||||
sd = self.description
|
||||
sock_info = self.pool.get_socket(all_credentials=all_credentials,
|
||||
min_wire_version=sd.min_wire_version,
|
||||
max_wire_version=sd.max_wire_version)
|
||||
|
||||
return sock_info
|
||||
|
||||
def maybe_return_socket(self, sock_info):
|
||||
self.pool.maybe_return_socket(sock_info)
|
||||
|
||||
def discard_socket(self, sock_info):
|
||||
self.pool.discard_socket(sock_info)
|
||||
|
||||
def start_request(self):
|
||||
# TODO: Remove implicit threadlocal requests, use explicit requests.
|
||||
self.pool.start_request()
|
||||
|
||||
@ -38,7 +38,7 @@ class MockPool(Pool):
|
||||
Pool.__init__(self,
|
||||
(default_host, default_port), PoolOptions(connect_timeout=20))
|
||||
|
||||
def get_socket(self, all_credentials):
|
||||
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
|
||||
client = self.client
|
||||
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
|
||||
if host_and_port in client.mock_down_hosts:
|
||||
@ -49,7 +49,9 @@ class MockPool(Pool):
|
||||
+ client.mock_members
|
||||
+ client.mock_mongoses), "bad host: %s" % host_and_port
|
||||
|
||||
sock_info = Pool.get_socket(self, all_credentials)
|
||||
sock_info = Pool.get_socket(
|
||||
self, all_credentials, min_wire_version, max_wire_version)
|
||||
|
||||
sock_info.mock_host = self.mock_host
|
||||
sock_info.mock_port = self.mock_port
|
||||
return sock_info
|
||||
|
||||
@ -30,7 +30,7 @@ from pymongo import MongoClient
|
||||
from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple
|
||||
from pymongo.errors import OperationFailure, ConfigurationError
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from test import client_context, host, port, SkipTest, unittest
|
||||
from test import client_context, host, port, SkipTest, unittest, Version
|
||||
|
||||
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS.
|
||||
GSSAPI_HOST = os.environ.get('GSSAPI_HOST')
|
||||
@ -256,11 +256,13 @@ class TestSCRAMSHA1(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.set_name = client_context.setname
|
||||
|
||||
cmd_line = client_context.cmd_line
|
||||
if 'SCRAM-SHA-1' not in cmd_line.get(
|
||||
'parsed', {}).get('setParameter',
|
||||
{}).get('authenticationMechanisms', ''):
|
||||
raise SkipTest('SCRAM-SHA-1 mechanism not enabled')
|
||||
# Before 2.7.7, SCRAM-SHA-1 had to be enabled from the command line.
|
||||
if client_context.version < Version(2, 7, 7):
|
||||
cmd_line = client_context.cmd_line
|
||||
if 'SCRAM-SHA-1' not in cmd_line.get(
|
||||
'parsed', {}).get('setParameter',
|
||||
{}).get('authenticationMechanisms', ''):
|
||||
raise SkipTest('SCRAM-SHA-1 mechanism not enabled')
|
||||
|
||||
client = client_context.rs_or_standalone_client
|
||||
client.pymongo_test.add_user(
|
||||
@ -294,9 +296,7 @@ class TestSCRAMSHA1(unittest.TestCase):
|
||||
client.pymongo_test.command('dbstats')
|
||||
|
||||
def tearDown(self):
|
||||
client_context.rs_or_standalone_client.pymongo_test.remove_user(
|
||||
'user',
|
||||
w=client_context.w)
|
||||
client_context.rs_or_standalone_client.pymongo_test.remove_user('user')
|
||||
|
||||
|
||||
class TestAuthURIOptions(unittest.TestCase):
|
||||
|
||||
@ -916,7 +916,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
|
||||
|
||||
# Simulate an authenticate() call on a different socket.
|
||||
credentials = auth._build_credentials_tuple(
|
||||
'MONGODB-CR', 'admin', db_user, db_pwd, {})
|
||||
'DEFAULT', 'admin', db_user, db_pwd, {})
|
||||
|
||||
c._cache_credentials('test', credentials, connect=False)
|
||||
|
||||
|
||||
@ -53,12 +53,6 @@ class MockPool(object):
|
||||
self.pool_id = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_socket(self):
|
||||
return MockSocketInfo()
|
||||
|
||||
def maybe_return_socket(self, _):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
with self._lock:
|
||||
self.pool_id += 1
|
||||
|
||||
@ -129,7 +129,7 @@ class SocketGetter(MongoThread):
|
||||
|
||||
def run_mongo_thread(self):
|
||||
self.state = 'get_socket'
|
||||
self.sock = self.pool.get_socket(all_credentials={})
|
||||
self.sock = self.pool.get_socket({}, 0, 0)
|
||||
self.state = 'sock'
|
||||
|
||||
|
||||
@ -323,16 +323,16 @@ class TestPooling(_TestPoolingBase):
|
||||
connect_timeout=1000,
|
||||
socket_timeout=1000)
|
||||
|
||||
sock0 = cx_pool.get_socket(all_credentials={})
|
||||
sock1 = cx_pool.get_socket(all_credentials={})
|
||||
sock0 = cx_pool.get_socket({}, 0, 0)
|
||||
sock1 = cx_pool.get_socket({}, 0, 0)
|
||||
|
||||
self.assertNotEqual(sock0, sock1)
|
||||
|
||||
# Now in a request, we'll get the same socket both times
|
||||
cx_pool.start_request()
|
||||
|
||||
sock2 = cx_pool.get_socket(all_credentials={})
|
||||
sock3 = cx_pool.get_socket(all_credentials={})
|
||||
sock2 = cx_pool.get_socket({}, 0, 0)
|
||||
sock3 = cx_pool.get_socket({}, 0, 0)
|
||||
self.assertEqual(sock2, sock3)
|
||||
|
||||
# Pool didn't keep reference to sock0 or sock1; sock2 and 3 are new
|
||||
@ -342,8 +342,8 @@ class TestPooling(_TestPoolingBase):
|
||||
# Return the request sock to pool
|
||||
cx_pool.end_request()
|
||||
|
||||
sock4 = cx_pool.get_socket(all_credentials={})
|
||||
sock5 = cx_pool.get_socket(all_credentials={})
|
||||
sock4 = cx_pool.get_socket({}, 0, 0)
|
||||
sock5 = cx_pool.get_socket({}, 0, 0)
|
||||
|
||||
# Not in a request any more, we get different sockets
|
||||
self.assertNotEqual(sock4, sock5)
|
||||
@ -373,10 +373,10 @@ class TestPooling(_TestPoolingBase):
|
||||
# Test Pool's _check_closed() method doesn't close a healthy socket.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
sock_info = cx_pool.get_socket(all_credentials={})
|
||||
sock_info = cx_pool.get_socket({}, 0, 0)
|
||||
cx_pool.maybe_return_socket(sock_info)
|
||||
|
||||
new_sock_info = cx_pool.get_socket(all_credentials={})
|
||||
new_sock_info = cx_pool.get_socket({}, 0, 0)
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
cx_pool.maybe_return_socket(new_sock_info)
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
@ -387,13 +387,13 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
|
||||
with cx_pool.get_socket(all_credentials={}) as sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
self.assertTrue(_closed(sock_info.sock))
|
||||
|
||||
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
|
||||
@ -406,14 +406,14 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get the request socket.
|
||||
with cx_pool.get_socket(all_credentials={}) as sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
sock_info.sock.close()
|
||||
|
||||
# Although the request socket died, we're still in a request with a
|
||||
# new socket.
|
||||
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertTrue(cx_pool.in_request())
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(new_sock_info, cx_pool._get_request_state())
|
||||
@ -430,7 +430,7 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get the request socket
|
||||
with cx_pool.get_socket(all_credentials={}) as sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
|
||||
@ -440,7 +440,7 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
# Although the request socket died, we're still in a request with a
|
||||
# new socket
|
||||
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertTrue(cx_pool.in_request())
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(new_sock_info, cx_pool._get_request_state())
|
||||
@ -459,7 +459,7 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get the request socket.
|
||||
with cx_pool.get_socket(all_credentials={}) as sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
|
||||
# End request.
|
||||
@ -470,7 +470,7 @@ class TestPooling(_TestPoolingBase):
|
||||
sock_info.sock.close()
|
||||
|
||||
# Dead socket detected and removed.
|
||||
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
|
||||
self.assertFalse(cx_pool.in_request())
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
@ -486,17 +486,17 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool.start_request()
|
||||
|
||||
# Get and close the request socket.
|
||||
with cx_pool.get_socket(all_credentials={}) as request_sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as request_sock_info:
|
||||
request_sock_info.sock.close()
|
||||
|
||||
# Detects closed socket and creates new one, semaphore value still 0.
|
||||
with cx_pool.get_socket(all_credentials={}) as request_sock_info_2:
|
||||
with cx_pool.get_socket({}, 0, 0) as request_sock_info_2:
|
||||
self.assertNotEqual(request_sock_info, request_sock_info_2)
|
||||
|
||||
cx_pool.end_request()
|
||||
|
||||
# Semaphore value now 1; we can get a socket.
|
||||
sock = cx_pool.get_socket(all_credentials={})
|
||||
sock = cx_pool.get_socket({}, 0, 0)
|
||||
sock.close()
|
||||
|
||||
def test_socket_reclamation(self):
|
||||
@ -519,7 +519,7 @@ class TestPooling(_TestPoolingBase):
|
||||
self.assertEqual(NO_REQUEST, cx_pool._get_request_state())
|
||||
cx_pool.start_request()
|
||||
self.assertEqual(NO_SOCKET_YET, cx_pool._get_request_state())
|
||||
with cx_pool.get_socket(all_credentials={}) as sock_info:
|
||||
with cx_pool.get_socket({}, 0, 0) as sock_info:
|
||||
self.assertEqual(sock_info, cx_pool._get_request_state())
|
||||
the_sock[0] = id(sock_info.sock)
|
||||
|
||||
@ -608,8 +608,8 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
def run_in_request():
|
||||
p.start_request()
|
||||
sock0 = p.get_socket(all_credentials={})
|
||||
sock1 = p.get_socket(all_credentials={})
|
||||
sock0 = p.get_socket({}, 0, 0)
|
||||
sock1 = p.get_socket({}, 0, 0)
|
||||
sock_ids.extend([id(sock0), id(sock1)])
|
||||
p.maybe_return_socket(sock0)
|
||||
p.maybe_return_socket(sock1)
|
||||
@ -681,7 +681,7 @@ class TestPooling(_TestPoolingBase):
|
||||
self.assertTrue(b_sock != c_sock)
|
||||
|
||||
# a_sock, created by parent process, is still in the pool
|
||||
d_sock = get_pool(a).get_socket(all_credentials={})
|
||||
d_sock = get_pool(a).get_socket({}, 0, 0)
|
||||
self.assertEqual(a_sock, d_sock)
|
||||
d_sock.close()
|
||||
|
||||
@ -690,10 +690,10 @@ class TestPooling(_TestPoolingBase):
|
||||
pool = self.create_pool(
|
||||
max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
|
||||
|
||||
sock_info = pool.get_socket(all_credentials={})
|
||||
sock_info = pool.get_socket({}, 0, 0)
|
||||
start = time.time()
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
pool.get_socket(all_credentials={})
|
||||
pool.get_socket({}, 0, 0)
|
||||
|
||||
duration = time.time() - start
|
||||
self.assertTrue(
|
||||
@ -708,7 +708,7 @@ class TestPooling(_TestPoolingBase):
|
||||
pool = self.create_pool(max_pool_size=1)
|
||||
|
||||
# Reach max_size.
|
||||
with pool.get_socket(all_credentials={}) as s1:
|
||||
with pool.get_socket({}, 0, 0) as s1:
|
||||
t = SocketGetter(self.c, pool)
|
||||
t.start()
|
||||
while t.state != 'get_socket':
|
||||
@ -730,8 +730,8 @@ class TestPooling(_TestPoolingBase):
|
||||
max_pool_size=2, wait_queue_multiple=wait_queue_multiple)
|
||||
|
||||
# Reach max_size sockets.
|
||||
socket_info_0 = pool.get_socket(all_credentials={})
|
||||
socket_info_1 = pool.get_socket(all_credentials={})
|
||||
socket_info_0 = pool.get_socket({}, 0, 0)
|
||||
socket_info_1 = pool.get_socket({}, 0, 0)
|
||||
|
||||
# Reach max_size * wait_queue_multiple waiters.
|
||||
threads = []
|
||||
@ -745,7 +745,7 @@ class TestPooling(_TestPoolingBase):
|
||||
self.assertEqual(t.state, 'get_socket')
|
||||
|
||||
with self.assertRaises(ExceededMaxWaiters):
|
||||
pool.get_socket(all_credentials={})
|
||||
pool.get_socket({}, 0, 0)
|
||||
socket_info_0.close()
|
||||
socket_info_1.close()
|
||||
|
||||
@ -754,7 +754,7 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
socks = []
|
||||
for _ in range(2):
|
||||
sock = pool.get_socket(all_credentials={})
|
||||
sock = pool.get_socket({}, 0, 0)
|
||||
socks.append(sock)
|
||||
threads = []
|
||||
for _ in range(30):
|
||||
@ -989,7 +989,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
# socket from pool" instead of the socket.error.
|
||||
for i in range(2):
|
||||
with self.assertRaises(socket.error):
|
||||
test_pool.get_socket(all_credentials={})
|
||||
test_pool.get_socket({}, 0, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -55,7 +55,7 @@ class MockPool(object):
|
||||
self.pool_id = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_socket(self, all_credentials):
|
||||
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
|
||||
return MockSocketInfo()
|
||||
|
||||
def maybe_return_socket(self, _):
|
||||
|
||||
@ -132,6 +132,10 @@ class TestURI(unittest.TestCase):
|
||||
split_options('authMechanism=GSSAPI'))
|
||||
self.assertEqual({'authmechanism': 'MONGODB-CR'},
|
||||
split_options('authMechanism=MONGODB-CR'))
|
||||
self.assertEqual({'authmechanism': 'SCRAM-SHA-1'},
|
||||
split_options('authMechanism=SCRAM-SHA-1'))
|
||||
self.assertRaises(ConfigurationError,
|
||||
split_options, 'authMechanism=foo')
|
||||
self.assertEqual({'authsource': 'foobar'}, split_options('authSource=foobar'))
|
||||
# maxPoolSize isn't yet a documented URI option.
|
||||
self.assertRaises(ConfigurationError, split_options, 'maxpoolsize=50')
|
||||
|
||||
@ -413,15 +413,15 @@ class TestRequestMixin(object):
|
||||
convenient methods for testing connection pools and requests
|
||||
"""
|
||||
def assertSameSock(self, pool):
|
||||
sock_info0 = pool.get_socket(all_credentials={})
|
||||
sock_info1 = pool.get_socket(all_credentials={})
|
||||
sock_info0 = pool.get_socket({}, 0, 0)
|
||||
sock_info1 = pool.get_socket({}, 0, 0)
|
||||
self.assertEqual(sock_info0, sock_info1)
|
||||
pool.maybe_return_socket(sock_info0)
|
||||
pool.maybe_return_socket(sock_info1)
|
||||
|
||||
def assertDifferentSock(self, pool):
|
||||
sock_info0 = pool.get_socket(all_credentials={})
|
||||
sock_info1 = pool.get_socket(all_credentials={})
|
||||
sock_info0 = pool.get_socket({}, 0, 0)
|
||||
sock_info1 = pool.get_socket({}, 0, 0)
|
||||
self.assertNotEqual(sock_info0, sock_info1)
|
||||
pool.maybe_return_socket(sock_info0)
|
||||
pool.maybe_return_socket(sock_info1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user