diff --git a/pymongo/auth.py b/pymongo/auth.py index f37a0b4e5..9eca28c98 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -254,9 +254,22 @@ def _parse_scram_response(response): return dict(item.split(b"=", 1) for item in response.split(b",")) +def _authenticate_scram_start(credentials, mechanism): + username = credentials.username + user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") + nonce = standard_b64encode(os.urandom(32)) + first_bare = b"n=" + user + b",r=" + nonce + + cmd = SON([('saslStart', 1), + ('mechanism', mechanism), + ('payload', Binary(b"n,," + first_bare)), + ('autoAuthorize', 1), + ('options', {'skipEmptyExchange': True})]) + return nonce, first_bare, cmd + + def _authenticate_scram(credentials, sock_info, mechanism): """Authenticate using SCRAM.""" - username = credentials.username if mechanism == 'SCRAM-SHA-256': digest = "sha256" @@ -272,16 +285,14 @@ def _authenticate_scram(credentials, sock_info, mechanism): # Make local _hmac = hmac.HMAC - user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") - nonce = standard_b64encode(os.urandom(32)) - first_bare = b"n=" + user + b",r=" + nonce - - cmd = SON([('saslStart', 1), - ('mechanism', mechanism), - ('payload', Binary(b"n,," + first_bare)), - ('autoAuthorize', 1), - ('options', {'skipEmptyExchange': True})]) - res = sock_info.command(source, cmd) + ctx = sock_info.auth_ctx.get(credentials) + if ctx and ctx.speculate_succeeded(): + nonce, first_bare = ctx.scram_data + res = ctx.speculative_authenticate + else: + nonce, first_bare, cmd = _authenticate_scram_start( + credentials, mechanism) + res = sock_info.command(source, cmd) server_first = res['payload'] parsed = _parse_scram_response(server_first) @@ -516,15 +527,17 @@ def _authenticate_cram_md5(credentials, sock_info): def _authenticate_x509(credentials, sock_info): """Authenticate using MONGODB-X509. """ - query = SON([('authenticate', 1), - ('mechanism', 'MONGODB-X509')]) - if credentials.username is not None: - query['user'] = credentials.username - elif sock_info.max_wire_version < 5: + ctx = sock_info.auth_ctx.get(credentials) + if ctx and ctx.speculate_succeeded(): + # MONGODB-X509 is done after the speculative auth step. + return + + cmd = _X509Context(credentials).speculate_command() + if credentials.username is None and sock_info.max_wire_version < 5: raise ConfigurationError( "A username is required for MONGODB-X509 authentication " "when connected to MongoDB versions older than 3.4.") - sock_info.command('$external', query) + sock_info.command('$external', cmd) def _authenticate_aws(credentials, sock_info): @@ -597,6 +610,62 @@ _AUTH_MAP = { } +class _AuthContext(object): + def __init__(self, credentials): + self.credentials = credentials + self.speculative_authenticate = None + + @staticmethod + def from_credentials(creds): + spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism) + if spec_cls: + return spec_cls(creds) + return None + + def speculate_command(self): + raise NotImplementedError + + def parse_response(self, ismaster): + self.speculative_authenticate = ismaster.speculative_authenticate + + def speculate_succeeded(self): + return bool(self.speculative_authenticate) + + +class _ScramContext(_AuthContext): + def __init__(self, credentials, mechanism): + super(_ScramContext, self).__init__(credentials) + self.scram_data = None + self.mechanism = mechanism + + def speculate_command(self): + nonce, first_bare, cmd = _authenticate_scram_start( + self.credentials, self.mechanism) + # The 'db' field is included only on the speculative command. + cmd['db'] = self.credentials.source + # Save for later use. + self.scram_data = (nonce, first_bare) + return cmd + + +class _X509Context(_AuthContext): + def speculate_command(self): + cmd = SON([('authenticate', 1), + ('mechanism', 'MONGODB-X509')]) + if self.credentials.username is not None: + cmd['user'] = self.credentials.username + return cmd + + +_SPECULATIVE_AUTH_MAP = { + 'MONGODB-X509': _X509Context, + 'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'), + 'SCRAM-SHA-256': functools.partial(_ScramContext, + mechanism='SCRAM-SHA-256'), + 'DEFAULT': functools.partial(_ScramContext, mechanism='SCRAM-SHA-256'), +} + + def authenticate(credentials, sock_info): """Authenticate sock_info.""" mechanism = credentials.mechanism diff --git a/pymongo/ismaster.py b/pymongo/ismaster.py index 5223a1276..fb2d6a868 100644 --- a/pymongo/ismaster.py +++ b/pymongo/ismaster.py @@ -169,6 +169,11 @@ class IsMaster(object): """ return self._doc.get('saslSupportedMechs', []) + @property + def speculative_authenticate(self): + """The speculativeAuthenticate field.""" + return self._doc.get('speculativeAuthenticate') + @property def topology_version(self): return self._doc.get('topologyVersion') diff --git a/pymongo/pool.py b/pymongo/pool.py index b255af220..de1d682e5 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -468,6 +468,15 @@ def _negotiate_creds(all_credentials): return None +def _speculative_context(all_credentials): + """Return the _AuthContext to use for speculative auth, if any. + """ + if all_credentials and len(all_credentials) == 1: + creds = next(itervalues(all_credentials)) + return auth._AuthContext.from_credentials(creds) + return None + + class SocketInfo(object): """Store a socket with some metadata. @@ -501,6 +510,7 @@ class SocketInfo(object): # Support for mechanism negotiation on the initial handshake. # Maps credential to saslSupportedMechs. self.negotiated_mechanisms = {} + self.auth_ctx = {} # The pool's generation changes with each reset() so we can close # sockets created before the last reset. @@ -522,6 +532,9 @@ class SocketInfo(object): creds = _negotiate_creds(all_credentials) if creds: cmd['saslSupportedMechs'] = creds.source + '.' + creds.username + auth_ctx = _speculative_context(all_credentials) + if auth_ctx: + cmd['speculativeAuthenticate'] = auth_ctx.speculate_command() ismaster = IsMaster(self.command('admin', cmd, publish_events=False)) self.is_writable = ismaster.is_writable @@ -541,6 +554,10 @@ class SocketInfo(object): self.op_msg_enabled = ismaster.max_wire_version >= 6 if creds: self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs + if auth_ctx: + auth_ctx.parse_response(ismaster) + if auth_ctx.speculate_succeeded(): + self.auth_ctx[auth_ctx.credentials] = auth_ctx return ismaster def command(self, dbname, spec, slave_ok=False, @@ -743,6 +760,7 @@ class SocketInfo(object): self.authset.add(credentials) # negotiated_mechanisms are no longer needed. self.negotiated_mechanisms.pop(credentials, None) + self.auth_ctx.pop(credentials, None) def validate_session(self, client, session): """Validate this session before use with client. diff --git a/test/test_auth.py b/test/test_auth.py index 14b2f9439..6dccc6ff1 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -418,18 +418,20 @@ class TestSCRAM(unittest.TestCase): client = rs_or_single_client_noauth( username='sha256', password='pwd', authSource='testscram', event_listeners=[listener]) - client.admin.command('isMaster') + client.testscram.command('dbstats') - # Assert we sent the skipEmptyExchange option. - first_event = listener.results['started'][0] - self.assertEqual(first_event.command_name, 'saslStart') - self.assertEqual( - first_event.command['options'], {'skipEmptyExchange': True}) + if client_context.version < (4, 4, -1): + # Assert we sent the skipEmptyExchange option. + first_event = listener.results['started'][0] + self.assertEqual(first_event.command_name, 'saslStart') + self.assertEqual( + first_event.command['options'], {'skipEmptyExchange': True}) # Assert the third exchange was skipped on servers that support it. + # Note that the first exchange occurs on the connection handshake. started = listener.started_command_names() - if client_context.version.at_least(4, 3, 3): - self.assertEqual(started, ['saslStart', 'saslContinue']) + if client_context.version.at_least(4, 4, -1): + self.assertEqual(started, ['saslContinue']) else: self.assertEqual( started, ['saslStart', 'saslContinue', 'saslContinue']) @@ -578,8 +580,13 @@ class TestSCRAM(unittest.TestCase): 'mongodb://both:pwd@%s:%d/testscram' % (host, port), event_listeners=[self.listener]) client.testscram.command('dbstats') - started = self.listener.results['started'][0] - self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256') + if client_context.version.at_least(4, 4, -1): + # Speculative authentication in 4.4+ sends saslStart with the + # handshake. + self.assertEqual(self.listener.results['started'], []) + else: + started = self.listener.results['started'][0] + self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256') client = rs_or_single_client_noauth( 'mongodb://both:pwd@%s:%d/testscram?authMechanism=SCRAM-SHA-1' diff --git a/test/test_database.py b/test/test_database.py index 0dfdccdea..15f3be70f 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -51,7 +51,8 @@ from test import (client_context, SkipTest, unittest, IntegrationTest) -from test.utils import (ignore_deprecations, +from test.utils import (EventListener, + ignore_deprecations, remove_all_users, rs_or_single_client_noauth, rs_or_single_client, @@ -677,14 +678,6 @@ class TestDatabase(IntegrationTest): admin_db_auth = self.client.admin users_db_auth = self.client.pymongo_test - # Non-root client. - client = rs_or_single_client_noauth() - admin_db = client.admin - users_db = client.pymongo_test - other_db = client.pymongo_test1 - - self.assertRaises(OperationFailure, users_db.test.find_one) - admin_db_auth.add_user( 'ro-admin', 'pass', @@ -695,15 +688,36 @@ class TestDatabase(IntegrationTest): 'user', 'pass', roles=["userAdmin", "readWrite"]) self.addCleanup(remove_all_users, users_db_auth) + # Non-root client. + listener = EventListener() + client = rs_or_single_client_noauth(event_listeners=[listener]) + admin_db = client.admin + users_db = client.pymongo_test + other_db = client.pymongo_test1 + + self.assertRaises(OperationFailure, users_db.test.find_one) + self.assertEqual(listener.started_command_names(), ['find']) + listener.reset() + # Regular user should be able to query its own db, but # no other. users_db.authenticate('user', 'pass') + if client_context.version.at_least(3, 0): + self.assertEqual(listener.started_command_names()[0], 'saslStart') + else: + self.assertEqual(listener.started_command_names()[0], 'getnonce') + self.assertEqual(0, users_db.test.count_documents({})) self.assertRaises(OperationFailure, other_db.test.find_one) + listener.reset() # Admin read-only user should be able to query any db, # but not write. admin_db.authenticate('ro-admin', 'pass') + if client_context.version.at_least(3, 0): + self.assertEqual(listener.started_command_names()[0], 'saslStart') + else: + self.assertEqual(listener.started_command_names()[0], 'getnonce') self.assertEqual(None, other_db.test.find_one()) self.assertRaises(OperationFailure, other_db.test.insert_one, {}) @@ -711,8 +725,23 @@ class TestDatabase(IntegrationTest): # Close all sockets. client.close() + listener.reset() # We should still be able to write to the regular user's db. self.assertTrue(users_db.test.delete_many({})) + names = listener.started_command_names() + if client_context.version.at_least(4, 4, -1): + # No speculation with multiple users (but we do skipEmptyExchange). + self.assertEqual( + names, ['saslStart', 'saslContinue', 'saslStart', + 'saslContinue', 'delete']) + elif client_context.version.at_least(3, 0): + self.assertEqual( + names, ['saslStart', 'saslContinue', 'saslContinue', + 'saslStart', 'saslContinue', 'saslContinue', 'delete']) + else: + self.assertEqual( + names, ['getnonce', 'authenticate', + 'getnonce', 'authenticate', 'delete']) # And read from other dbs... self.assertEqual(0, other_db.test.count_documents({})) diff --git a/test/test_ssl.py b/test/test_ssl.py index 0987354c7..006b62ec7 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -39,9 +39,11 @@ from test import (IntegrationTest, SkipTest, unittest, HAVE_IPADDRESS) -from test.utils import (remove_all_users, +from test.utils import (EventListener, cat_files, - connected) + connected, + remove_all_users) + _HAVE_PYOPENSSL = False try: @@ -582,16 +584,24 @@ class TestSSL(IntegrationTest): self.assertRaises(OperationFailure, noauth.pymongo_test.test.count) + listener = EventListener() auth = MongoClient( client_context.pair, authMechanism='MONGODB-X509', ssl=True, ssl_cert_reqs=ssl.CERT_NONE, - ssl_certfile=CLIENT_PEM) + ssl_certfile=CLIENT_PEM, + event_listeners=[listener]) if client_context.version.at_least(3, 3, 12): # No error auth.pymongo_test.test.find_one() + names = listener.started_command_names() + if client_context.version.at_least(4, 4, -1): + # Speculative auth skips the authenticate command. + self.assertEqual(names, ['find']) + else: + self.assertEqual(names, ['authenticate', 'find']) else: # Should require a username with self.assertRaises(ConfigurationError):