PYTHON-764 SCRAM-SHA-1 automatic upgrade / downgrade.

This commit is contained in:
A. Jesse Jiryu Davis 2014-10-23 16:30:55 -04:00
parent 4b70b2a6f7
commit ee11436675
15 changed files with 117 additions and 68 deletions

View File

@ -36,7 +36,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
MECHANISMS = frozenset(
['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN', 'SCRAM-SHA-1'])
['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN', 'SCRAM-SHA-1', 'DEFAULT'])
"""The authentication mechanisms supported by PyMongo."""
@ -337,6 +337,13 @@ def _authenticate_mongo_cr(credentials, sock_info):
sock_info.command(source, query)
def _authenticate_default(credentials, sock_info):
if sock_info.max_wire_version >= 3:
return _authenticate_scram_sha1(credentials, sock_info)
else:
return _authenticate_mongo_cr(credentials, sock_info)
_AUTH_MAP = {
'CRAM-MD5': _authenticate_cram_md5,
'GSSAPI': _authenticate_gssapi,
@ -344,6 +351,7 @@ _AUTH_MAP = {
'MONGODB-X509': _authenticate_x509,
'PLAIN': _authenticate_plain,
'SCRAM-SHA-1': _authenticate_scram_sha1,
'DEFAULT': _authenticate_default,
}

View File

@ -28,7 +28,7 @@ def _parse_credentials(username, password, database, options):
"""Parse authentication credentials."""
if username is None:
return None
mechanism = options.get('authmechanism', 'MONGODB-CR')
mechanism = options.get('authmechanism', 'DEFAULT')
source = options.get('authsource', database or 'admin')
return _build_credentials_tuple(
mechanism, source, _unicode(username), _unicode(password), options)

View File

@ -851,7 +851,7 @@ class Database(common.BaseObject):
raise
def authenticate(self, name, password=None,
source=None, mechanism='MONGODB-CR', **kwargs):
source=None, mechanism='DEFAULT', **kwargs):
"""Authenticate to use this database.
Authentication lasts for the life of the underlying client
@ -883,11 +883,15 @@ class Database(common.BaseObject):
specified the current database is used.
- `mechanism` (optional): See
:data:`~pymongo.auth.MECHANISMS` for options.
Defaults to MONGODB-CR (MongoDB Challenge Response protocol)
By default, use SCRAM-SHA-1 with MongoDB 2.8 and later,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
- `gssapiServiceName` (optional): Used with the GSSAPI mechanism
to specify the service name portion of the service principal name.
Defaults to 'mongodb'.
.. versionadded:: 2.8
Use SCRAM-SHA-1 with MongoDB 2.8 and later.
.. versionchanged:: 2.5
Added the `source` and `mechanism` parameters. :meth:`authenticate`
now raises a subclass of :class:`~pymongo.errors.PyMongoError` if

View File

@ -336,7 +336,7 @@ class MongoClient(common.BaseObject):
# get_socket() logs out of the database if logged in with old
# credentials, and logs in with new ones.
with server.pool.get_socket(all_credentials) as sock_info:
with server.get_socket(all_credentials) as sock_info:
sock_info.authenticate(credentials)
# If several threads run _cache_credentials at once, last one wins.
@ -652,7 +652,7 @@ class MongoClient(common.BaseObject):
# Avoid race when other threads log in or out.
all_credentials = self.__all_credentials.copy()
with server.pool.get_socket(all_credentials) as sock_info:
with server.get_socket(all_credentials) as sock_info:
return not pool._closed(sock_info.sock)
except (socket.error, ConnectionFailure):
@ -1095,7 +1095,7 @@ class MongoClient(common.BaseObject):
# Avoid race when other threads log in or out.
all_credentials = self.__all_credentials.copy()
with server.pool.get_socket(all_credentials) as sock:
with server.get_socket(all_credentials) as sock:
if username is not None:
get_nonce_cmd = SON([("copydbgetnonce", 1),
("fromhost", from_host)])

View File

@ -138,7 +138,7 @@ class Monitor(object):
Returns a ServerDescription, or None on error.
"""
try:
with self._pool.get_socket(all_credentials={}) as sock_info:
with self._pool.get_socket({}, 0, 0) as sock_info:
response, round_trip_time = self._check_with_socket(sock_info)
old_rtts = self._server_description.round_trip_times
if old_rtts:

View File

@ -150,6 +150,9 @@ class SocketInfo(object):
self.last_checkout = time.time()
self.pool_ref = weakref.ref(pool)
self._min_wire_version = None
self._max_wire_version = None
# The pool's pool_id changes with each reset() so we can close sockets
# created before the last reset.
self.pool_id = pool.pool_id
@ -257,6 +260,20 @@ class SocketInfo(object):
self.sock.close()
except:
pass
def set_wire_version_range(self, min_wire_version, max_wire_version):
self._min_wire_version = min_wire_version
self._max_wire_version = max_wire_version
@property
def min_wire_version(self):
assert self._min_wire_version is not None
return self._min_wire_version
@property
def max_wire_version(self):
assert self._max_wire_version is not None
return self._max_wire_version
def exhaust(self, exhaust):
self._exhaust = exhaust
@ -425,19 +442,25 @@ class Pool:
sock.settimeout(self.opts.socket_timeout)
return SocketInfo(sock, self, hostname)
def get_socket(self, all_credentials):
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
"""Get a socket from the pool.
Returns a :class:`SocketInfo` object wrapping a connected
:class:`socket.socket`, and a bool saying whether the socket was from
the pool or freshly created.
:class:`socket.socket`.
The socket is logged in or out as necessary to match ``all_credentials``
using the correct authentication mechanism for the server's wire
protocol version.
:Parameters:
- `all_credentials`: dict, maps auth source to MongoCredential.
- `min_wire_version`: int, minimum protocol the server supports.
- `max_wire_version`: int, maximum protocol the server supports.
"""
# First get a socket, then attempt authentication. Simplifies
# semaphore management in the face of network errors during auth.
sock_info = self._get_socket_no_auth()
sock_info.set_wire_version_range(min_wire_version, max_wire_version)
try:
sock_info.check_auth(all_credentials)
return sock_info

View File

@ -60,7 +60,7 @@ class Server(object):
"""
request_id, data = self._check_bson_size(message)
try:
with self._pool.get_socket(all_credentials) as sock_info:
with self.get_socket(all_credentials) as sock_info:
sock_info.send_message(data)
except socket.error as exc:
self._raise_connection_failure(exc)
@ -82,7 +82,7 @@ class Server(object):
"""
request_id, data = self._check_bson_size(message)
try:
with self._pool.get_socket(all_credentials) as sock_info:
with self.get_socket(all_credentials) as sock_info:
sock_info.exhaust(exhaust)
sock_info.send_message(data)
response_data = sock_info.receive_message(1, request_id)
@ -99,6 +99,20 @@ class Server(object):
except socket.error as exc:
self._raise_connection_failure(exc)
def get_socket(self, all_credentials):
sd = self.description
sock_info = self.pool.get_socket(all_credentials=all_credentials,
min_wire_version=sd.min_wire_version,
max_wire_version=sd.max_wire_version)
return sock_info
def maybe_return_socket(self, sock_info):
self.pool.maybe_return_socket(sock_info)
def discard_socket(self, sock_info):
self.pool.discard_socket(sock_info)
def start_request(self):
# TODO: Remove implicit threadlocal requests, use explicit requests.
self.pool.start_request()

View File

@ -38,7 +38,7 @@ class MockPool(Pool):
Pool.__init__(self,
(default_host, default_port), PoolOptions(connect_timeout=20))
def get_socket(self, all_credentials):
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
client = self.client
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
if host_and_port in client.mock_down_hosts:
@ -49,7 +49,9 @@ class MockPool(Pool):
+ client.mock_members
+ client.mock_mongoses), "bad host: %s" % host_and_port
sock_info = Pool.get_socket(self, all_credentials)
sock_info = Pool.get_socket(
self, all_credentials, min_wire_version, max_wire_version)
sock_info.mock_host = self.mock_host
sock_info.mock_port = self.mock_port
return sock_info

View File

@ -30,7 +30,7 @@ from pymongo import MongoClient
from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple
from pymongo.errors import OperationFailure, ConfigurationError
from pymongo.read_preferences import ReadPreference
from test import client_context, host, port, SkipTest, unittest
from test import client_context, host, port, SkipTest, unittest, Version
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS.
GSSAPI_HOST = os.environ.get('GSSAPI_HOST')
@ -256,11 +256,13 @@ class TestSCRAMSHA1(unittest.TestCase):
def setUp(self):
self.set_name = client_context.setname
cmd_line = client_context.cmd_line
if 'SCRAM-SHA-1' not in cmd_line.get(
'parsed', {}).get('setParameter',
{}).get('authenticationMechanisms', ''):
raise SkipTest('SCRAM-SHA-1 mechanism not enabled')
# Before 2.7.7, SCRAM-SHA-1 had to be enabled from the command line.
if client_context.version < Version(2, 7, 7):
cmd_line = client_context.cmd_line
if 'SCRAM-SHA-1' not in cmd_line.get(
'parsed', {}).get('setParameter',
{}).get('authenticationMechanisms', ''):
raise SkipTest('SCRAM-SHA-1 mechanism not enabled')
client = client_context.rs_or_standalone_client
client.pymongo_test.add_user(
@ -294,9 +296,7 @@ class TestSCRAMSHA1(unittest.TestCase):
client.pymongo_test.command('dbstats')
def tearDown(self):
client_context.rs_or_standalone_client.pymongo_test.remove_user(
'user',
w=client_context.w)
client_context.rs_or_standalone_client.pymongo_test.remove_user('user')
class TestAuthURIOptions(unittest.TestCase):

View File

@ -916,7 +916,7 @@ class TestClient(IntegrationTest, TestRequestMixin):
# Simulate an authenticate() call on a different socket.
credentials = auth._build_credentials_tuple(
'MONGODB-CR', 'admin', db_user, db_pwd, {})
'DEFAULT', 'admin', db_user, db_pwd, {})
c._cache_credentials('test', credentials, connect=False)

View File

@ -53,12 +53,6 @@ class MockPool(object):
self.pool_id = 0
self._lock = threading.Lock()
def get_socket(self):
return MockSocketInfo()
def maybe_return_socket(self, _):
pass
def reset(self):
with self._lock:
self.pool_id += 1

View File

@ -129,7 +129,7 @@ class SocketGetter(MongoThread):
def run_mongo_thread(self):
self.state = 'get_socket'
self.sock = self.pool.get_socket(all_credentials={})
self.sock = self.pool.get_socket({}, 0, 0)
self.state = 'sock'
@ -323,16 +323,16 @@ class TestPooling(_TestPoolingBase):
connect_timeout=1000,
socket_timeout=1000)
sock0 = cx_pool.get_socket(all_credentials={})
sock1 = cx_pool.get_socket(all_credentials={})
sock0 = cx_pool.get_socket({}, 0, 0)
sock1 = cx_pool.get_socket({}, 0, 0)
self.assertNotEqual(sock0, sock1)
# Now in a request, we'll get the same socket both times
cx_pool.start_request()
sock2 = cx_pool.get_socket(all_credentials={})
sock3 = cx_pool.get_socket(all_credentials={})
sock2 = cx_pool.get_socket({}, 0, 0)
sock3 = cx_pool.get_socket({}, 0, 0)
self.assertEqual(sock2, sock3)
# Pool didn't keep reference to sock0 or sock1; sock2 and 3 are new
@ -342,8 +342,8 @@ class TestPooling(_TestPoolingBase):
# Return the request sock to pool
cx_pool.end_request()
sock4 = cx_pool.get_socket(all_credentials={})
sock5 = cx_pool.get_socket(all_credentials={})
sock4 = cx_pool.get_socket({}, 0, 0)
sock5 = cx_pool.get_socket({}, 0, 0)
# Not in a request any more, we get different sockets
self.assertNotEqual(sock4, sock5)
@ -373,10 +373,10 @@ class TestPooling(_TestPoolingBase):
# Test Pool's _check_closed() method doesn't close a healthy socket.
cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check.
sock_info = cx_pool.get_socket(all_credentials={})
sock_info = cx_pool.get_socket({}, 0, 0)
cx_pool.maybe_return_socket(sock_info)
new_sock_info = cx_pool.get_socket(all_credentials={})
new_sock_info = cx_pool.get_socket({}, 0, 0)
self.assertEqual(sock_info, new_sock_info)
cx_pool.maybe_return_socket(new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
@ -387,13 +387,13 @@ class TestPooling(_TestPoolingBase):
cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket(all_credentials={}) as sock_info:
with cx_pool.get_socket({}, 0, 0) as sock_info:
# Simulate a closed socket without telling the SocketInfo it's
# closed.
sock_info.sock.close()
self.assertTrue(_closed(sock_info.sock))
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertNotEqual(sock_info, new_sock_info)
@ -406,14 +406,14 @@ class TestPooling(_TestPoolingBase):
cx_pool.start_request()
# Get the request socket.
with cx_pool.get_socket(all_credentials={}) as sock_info:
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(sock_info, cx_pool._get_request_state())
sock_info.sock.close()
# Although the request socket died, we're still in a request with a
# new socket.
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertTrue(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
@ -430,7 +430,7 @@ class TestPooling(_TestPoolingBase):
cx_pool.start_request()
# Get the request socket
with cx_pool.get_socket(all_credentials={}) as sock_info:
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(sock_info, cx_pool._get_request_state())
@ -440,7 +440,7 @@ class TestPooling(_TestPoolingBase):
# Although the request socket died, we're still in a request with a
# new socket
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertTrue(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
@ -459,7 +459,7 @@ class TestPooling(_TestPoolingBase):
cx_pool.start_request()
# Get the request socket.
with cx_pool.get_socket(all_credentials={}) as sock_info:
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(sock_info, cx_pool._get_request_state())
# End request.
@ -470,7 +470,7 @@ class TestPooling(_TestPoolingBase):
sock_info.sock.close()
# Dead socket detected and removed.
with cx_pool.get_socket(all_credentials={}) as new_sock_info:
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
self.assertFalse(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(0, len(cx_pool.sockets))
@ -486,17 +486,17 @@ class TestPooling(_TestPoolingBase):
cx_pool.start_request()
# Get and close the request socket.
with cx_pool.get_socket(all_credentials={}) as request_sock_info:
with cx_pool.get_socket({}, 0, 0) as request_sock_info:
request_sock_info.sock.close()
# Detects closed socket and creates new one, semaphore value still 0.
with cx_pool.get_socket(all_credentials={}) as request_sock_info_2:
with cx_pool.get_socket({}, 0, 0) as request_sock_info_2:
self.assertNotEqual(request_sock_info, request_sock_info_2)
cx_pool.end_request()
# Semaphore value now 1; we can get a socket.
sock = cx_pool.get_socket(all_credentials={})
sock = cx_pool.get_socket({}, 0, 0)
sock.close()
def test_socket_reclamation(self):
@ -519,7 +519,7 @@ class TestPooling(_TestPoolingBase):
self.assertEqual(NO_REQUEST, cx_pool._get_request_state())
cx_pool.start_request()
self.assertEqual(NO_SOCKET_YET, cx_pool._get_request_state())
with cx_pool.get_socket(all_credentials={}) as sock_info:
with cx_pool.get_socket({}, 0, 0) as sock_info:
self.assertEqual(sock_info, cx_pool._get_request_state())
the_sock[0] = id(sock_info.sock)
@ -608,8 +608,8 @@ class TestPooling(_TestPoolingBase):
def run_in_request():
p.start_request()
sock0 = p.get_socket(all_credentials={})
sock1 = p.get_socket(all_credentials={})
sock0 = p.get_socket({}, 0, 0)
sock1 = p.get_socket({}, 0, 0)
sock_ids.extend([id(sock0), id(sock1)])
p.maybe_return_socket(sock0)
p.maybe_return_socket(sock1)
@ -681,7 +681,7 @@ class TestPooling(_TestPoolingBase):
self.assertTrue(b_sock != c_sock)
# a_sock, created by parent process, is still in the pool
d_sock = get_pool(a).get_socket(all_credentials={})
d_sock = get_pool(a).get_socket({}, 0, 0)
self.assertEqual(a_sock, d_sock)
d_sock.close()
@ -690,10 +690,10 @@ class TestPooling(_TestPoolingBase):
pool = self.create_pool(
max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
sock_info = pool.get_socket(all_credentials={})
sock_info = pool.get_socket({}, 0, 0)
start = time.time()
with self.assertRaises(ConnectionFailure):
pool.get_socket(all_credentials={})
pool.get_socket({}, 0, 0)
duration = time.time() - start
self.assertTrue(
@ -708,7 +708,7 @@ class TestPooling(_TestPoolingBase):
pool = self.create_pool(max_pool_size=1)
# Reach max_size.
with pool.get_socket(all_credentials={}) as s1:
with pool.get_socket({}, 0, 0) as s1:
t = SocketGetter(self.c, pool)
t.start()
while t.state != 'get_socket':
@ -730,8 +730,8 @@ class TestPooling(_TestPoolingBase):
max_pool_size=2, wait_queue_multiple=wait_queue_multiple)
# Reach max_size sockets.
socket_info_0 = pool.get_socket(all_credentials={})
socket_info_1 = pool.get_socket(all_credentials={})
socket_info_0 = pool.get_socket({}, 0, 0)
socket_info_1 = pool.get_socket({}, 0, 0)
# Reach max_size * wait_queue_multiple waiters.
threads = []
@ -745,7 +745,7 @@ class TestPooling(_TestPoolingBase):
self.assertEqual(t.state, 'get_socket')
with self.assertRaises(ExceededMaxWaiters):
pool.get_socket(all_credentials={})
pool.get_socket({}, 0, 0)
socket_info_0.close()
socket_info_1.close()
@ -754,7 +754,7 @@ class TestPooling(_TestPoolingBase):
socks = []
for _ in range(2):
sock = pool.get_socket(all_credentials={})
sock = pool.get_socket({}, 0, 0)
socks.append(sock)
threads = []
for _ in range(30):
@ -989,7 +989,7 @@ class TestPoolMaxSize(_TestPoolingBase):
# socket from pool" instead of the socket.error.
for i in range(2):
with self.assertRaises(socket.error):
test_pool.get_socket(all_credentials={})
test_pool.get_socket({}, 0, 0)
if __name__ == "__main__":

View File

@ -55,7 +55,7 @@ class MockPool(object):
self.pool_id = 0
self._lock = threading.Lock()
def get_socket(self, all_credentials):
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
return MockSocketInfo()
def maybe_return_socket(self, _):

View File

@ -132,6 +132,10 @@ class TestURI(unittest.TestCase):
split_options('authMechanism=GSSAPI'))
self.assertEqual({'authmechanism': 'MONGODB-CR'},
split_options('authMechanism=MONGODB-CR'))
self.assertEqual({'authmechanism': 'SCRAM-SHA-1'},
split_options('authMechanism=SCRAM-SHA-1'))
self.assertRaises(ConfigurationError,
split_options, 'authMechanism=foo')
self.assertEqual({'authsource': 'foobar'}, split_options('authSource=foobar'))
# maxPoolSize isn't yet a documented URI option.
self.assertRaises(ConfigurationError, split_options, 'maxpoolsize=50')

View File

@ -413,15 +413,15 @@ class TestRequestMixin(object):
convenient methods for testing connection pools and requests
"""
def assertSameSock(self, pool):
sock_info0 = pool.get_socket(all_credentials={})
sock_info1 = pool.get_socket(all_credentials={})
sock_info0 = pool.get_socket({}, 0, 0)
sock_info1 = pool.get_socket({}, 0, 0)
self.assertEqual(sock_info0, sock_info1)
pool.maybe_return_socket(sock_info0)
pool.maybe_return_socket(sock_info1)
def assertDifferentSock(self, pool):
sock_info0 = pool.get_socket(all_credentials={})
sock_info1 = pool.get_socket(all_credentials={})
sock_info0 = pool.get_socket({}, 0, 0)
sock_info1 = pool.get_socket({}, 0, 0)
self.assertNotEqual(sock_info0, sock_info1)
pool.maybe_return_socket(sock_info0)
pool.maybe_return_socket(sock_info1)