From e3d65107619fb5f1e6978b90c59129d8ea303c59 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Tue, 21 Oct 2014 15:17:53 -0400 Subject: [PATCH] PYTHON-764 SCRAM-SHA-1 automatic upgrade / downgrade. --- pymongo/auth.py | 10 ++++++++- pymongo/database.py | 8 +++++-- pymongo/member.py | 25 +++++++++++++++++++++ pymongo/mongo_client.py | 33 ++++++++++++++-------------- pymongo/mongo_replica_set_client.py | 34 ++++++++++++++--------------- pymongo/pool.py | 17 +++++++++++++++ test/test_auth.py | 4 ++-- test/test_uri_parser.py | 4 ++++ 8 files changed, 96 insertions(+), 39 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index f465dc4de..046256f81 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -44,7 +44,7 @@ from pymongo.errors import ConfigurationError, OperationFailure MECHANISMS = frozenset( - ['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN', 'SCRAM-SHA-1']) + ['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN', 'SCRAM-SHA-1', 'DEFAULT']) """The authentication mechanisms supported by PyMongo.""" @@ -327,6 +327,13 @@ def _authenticate_mongo_cr(credentials, sock_info, cmd_func): cmd_func(sock_info, source, query) +def _authenticate_default(credentials, sock_info, cmd_func): + if sock_info.max_wire_version >= 3: + return _authenticate_scram_sha1(credentials, sock_info, cmd_func) + else: + return _authenticate_mongo_cr(credentials, sock_info, cmd_func) + + _AUTH_MAP = { 'CRAM-MD5': _authenticate_cram_md5, 'GSSAPI': _authenticate_gssapi, @@ -334,6 +341,7 @@ _AUTH_MAP = { 'MONGODB-X509': _authenticate_x509, 'PLAIN': _authenticate_plain, 'SCRAM-SHA-1': _authenticate_scram_sha1, + 'DEFAULT': _authenticate_default, } diff --git a/pymongo/database.py b/pymongo/database.py index b52839f04..f4cce82bd 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -842,7 +842,7 @@ class Database(common.BaseObject): raise def authenticate(self, name, password=None, - source=None, mechanism='MONGODB-CR', **kwargs): + source=None, mechanism='DEFAULT', **kwargs): """Authenticate to use this database. Authentication lasts for the life of the underlying client @@ -878,11 +878,15 @@ class Database(common.BaseObject): specified the current database is used. - `mechanism` (optional): See :data:`~pymongo.auth.MECHANISMS` for options. - Defaults to MONGODB-CR (MongoDB Challenge Response protocol) + By default, use SCRAM-SHA-1 with MongoDB 2.8 and later, + MONGODB-CR (MongoDB Challenge Response protocol) for older servers. - `gssapiServiceName` (optional): Used with the GSSAPI mechanism to specify the service name portion of the service principal name. Defaults to 'mongodb'. + .. versionadded:: 2.8 + Use SCRAM-SHA-1 with MongoDB 2.8 and later. + .. versionchanged:: 2.5 Added the `source` and `mechanism` parameters. :meth:`authenticate` now raises a subclass of :class:`~pymongo.errors.PyMongoError` if diff --git a/pymongo/member.py b/pymongo/member.py index dbdafdf0f..5ff603d9a 100644 --- a/pymongo/member.py +++ b/pymongo/member.py @@ -145,6 +145,31 @@ class Member(object): return False + def get_socket(self, force=False): + sock_info = self.pool.get_socket(force) + sock_info.set_wire_version_range(self.min_wire_version, + self.max_wire_version) + + return sock_info + + def maybe_return_socket(self, sock_info): + self.pool.maybe_return_socket(sock_info) + + def discard_socket(self, sock_info): + self.pool.discard_socket(sock_info) + + def start_request(self): + self.pool.start_request() + + def in_request(self): + return self.pool.in_request() + + def end_request(self): + self.pool.end_request() + + def reset(self): + self.pool.reset() + def __str__(self): return '' % ( self.host[0], self.host[1], self.is_primary) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ce7be9956..3ef5682c0 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -380,7 +380,7 @@ class MongoClient(common.BaseObject): raise ConnectionFailure(str(e)) if username: - mechanism = options.get('authmechanism', 'MONGODB-CR') + mechanism = options.get('authmechanism', 'DEFAULT') source = ( options.get('authsource') or self.__default_database_name @@ -470,7 +470,7 @@ class MongoClient(common.BaseObject): auth.authenticate(credentials, sock_info, self.__simple_command) sock_info.authset.add(credentials) finally: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) self.__auth_credentials[source] = credentials @@ -914,12 +914,11 @@ class MongoClient(common.BaseObject): Calls disconnect() on error. """ - connection_pool = member.pool try: - if self.auto_start_request and not connection_pool.in_request(): - connection_pool.start_request() + if self.auto_start_request and not member.in_request(): + member.start_request() - sock_info = connection_pool.get_socket() + sock_info = member.get_socket() except socket.error, why: self.disconnect() @@ -934,7 +933,7 @@ class MongoClient(common.BaseObject): try: self.__check_auth(sock_info) except: - connection_pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) raise return sock_info @@ -961,7 +960,7 @@ class MongoClient(common.BaseObject): # Close sockets promptly. if member: - member.pool.reset() + member.reset() def close(self): """Alias for :meth:`disconnect` @@ -1006,12 +1005,12 @@ class MongoClient(common.BaseObject): sock_info = None try: try: - sock_info = member.pool.get_socket() + sock_info = member.get_socket() return not pool._closed(sock_info.sock) except (socket.error, ConnectionFailure): return False finally: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) def set_cursor_manager(self, manager_class): """Set this client's cursor manager. @@ -1147,7 +1146,7 @@ class MongoClient(common.BaseObject): sock_info.close() raise finally: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) def __receive_data_on_socket(self, length, sock_info): """Lowest level receive operation. @@ -1215,15 +1214,15 @@ class MongoClient(common.BaseObject): if "network_timeout" in kwargs: sock_info.sock.settimeout(self.__net_timeout) - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) return (None, (response, sock_info, member.pool)) except (ConnectionFailure, socket.error), e: self.disconnect() - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) raise AutoReconnect(str(e)) except: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) raise def _exhaust_next(self, sock_info): @@ -1267,7 +1266,7 @@ class MongoClient(common.BaseObject): :meth:`start_request` previously returned None """ member = self.__ensure_member() - member.pool.start_request() + member.start_request() return pool.Request(self) def in_request(self): @@ -1275,7 +1274,7 @@ class MongoClient(common.BaseObject): reserved for its exclusive use. """ member = self.__member # Don't try to connect if disconnected. - return member and member.pool.in_request() + return member and member.in_request() def end_request(self): """Undo :meth:`start_request`. If :meth:`end_request` is called as many @@ -1295,7 +1294,7 @@ class MongoClient(common.BaseObject): """ member = self.__member # Don't try to connect if disconnected. if member: - member.pool.end_request() + member.end_request() def __eq__(self, other): if isinstance(other, self.__class__): diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index a53b0af69..95c5cf9ac 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -704,7 +704,7 @@ class MongoReplicaSetClient(common.BaseObject): raise ConnectionFailure(str(e)) if username: - mechanism = options.get('authmechanism', 'MONGODB-CR') + mechanism = options.get('authmechanism', 'DEFAULT') source = ( options.get('authsource') or self.__default_database_name @@ -825,7 +825,7 @@ class MongoReplicaSetClient(common.BaseObject): auth.authenticate(credentials, sock_info, self.__simple_command) sock_info.authset.add(credentials) finally: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) self.__auth_credentials[source] = credentials @@ -1151,7 +1151,7 @@ class MongoReplicaSetClient(common.BaseObject): sock_info = self.__socket(member, force=True) response, ping_time = self.__simple_command( sock_info, 'admin', {'ismaster': 1}) - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) new_member = member.clone_with(response, ping_time) else: response, pool, ping_time = self.__is_master(node) @@ -1189,7 +1189,7 @@ class MongoReplicaSetClient(common.BaseObject): except (ConnectionFailure, socket.error), why: if member: - member.pool.discard_socket(sock_info) + member.discard_socket(sock_info) errors.append("%s:%d: %s" % (node[0], node[1], str(why))) if hosts: break @@ -1217,7 +1217,7 @@ class MongoReplicaSetClient(common.BaseObject): # Not a member of this set. continue - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) new_member = member.clone_with(res, ping_time) else: res, connection_pool, ping_time = self.__is_master(host) @@ -1232,7 +1232,7 @@ class MongoReplicaSetClient(common.BaseObject): except (ConnectionFailure, socket.error): if member: - member.pool.discard_socket(sock_info) + member.discard_socket(sock_info) continue if res['ismaster']: @@ -1309,12 +1309,12 @@ class MongoReplicaSetClient(common.BaseObject): if self.auto_start_request and not self.in_request(): self.start_request() - sock_info = member.pool.get_socket(force=force) + sock_info = member.get_socket(force=force) try: self.__check_auth(sock_info) except OperationFailure: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) raise return sock_info @@ -1335,7 +1335,7 @@ class MongoReplicaSetClient(common.BaseObject): """ rs_state = self.__rs_state if rs_state.primary_member: - rs_state.primary_member.pool.reset() + rs_state.primary_member.reset() threadlocal = self.__make_threadlocal() self.__rs_state = rs_state.clone_without_writer(threadlocal) @@ -1400,7 +1400,7 @@ class MongoReplicaSetClient(common.BaseObject): return False finally: if primary: - primary.pool.maybe_return_socket(sock_info) + primary.maybe_return_socket(sock_info) def __check_response_to_last_error(self, response, is_command): """Check a response to a lastError message for errors. @@ -1527,7 +1527,7 @@ class MongoReplicaSetClient(common.BaseObject): except OperationFailure: raise except(ConnectionFailure, socket.error), why: - member.pool.discard_socket(sock_info) + member.discard_socket(sock_info) if _connection_to_use in (None, -1): self.disconnect() raise AutoReconnect(str(why)) @@ -1535,7 +1535,7 @@ class MongoReplicaSetClient(common.BaseObject): sock_info.close() raise finally: - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) def __send_and_receive(self, member, msg, **kwargs): """Send a message on the given socket and return the response data. @@ -1557,13 +1557,13 @@ class MongoReplicaSetClient(common.BaseObject): if not exhaust: if "network_timeout" in kwargs: sock_info.sock.settimeout(self.__net_timeout) - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) return response, sock_info, member.pool except: if sock_info is not None: sock_info.close() - member.pool.maybe_return_socket(sock_info) + member.maybe_return_socket(sock_info) raise def __try_read(self, member, msg, **kwargs): @@ -1637,7 +1637,7 @@ class MongoReplicaSetClient(common.BaseObject): if not member: raise AutoReconnect(error_message) - return member.pool.pair, self.__try_read( + return member.pair, self.__try_read( member, msg, **kwargs) except AutoReconnect: if _connection_to_use in (-1, rs_state.writer): @@ -1766,7 +1766,7 @@ class MongoReplicaSetClient(common.BaseObject): # within a request. if 1 == self.__request_counter.inc(): for member in self.__rs_state.members: - member.pool.start_request() + member.start_request() return pool.Request(self) @@ -1795,7 +1795,7 @@ class MongoReplicaSetClient(common.BaseObject): if 0 == self.__request_counter.dec(): for member in rs_state.members: # No effect if not in a request - member.pool.end_request() + member.end_request() rs_state.unpin_host() diff --git a/pymongo/pool.py b/pymongo/pool.py index 3b728ea31..c34ebda4e 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -63,6 +63,9 @@ class SocketInfo(object): self.last_checkout = time.time() self.forced = False + self._min_wire_version = None + self._max_wire_version = None + # The pool's pool_id changes with each reset() so we can close sockets # created before the last reset. self.pool_id = pool_id @@ -74,6 +77,20 @@ class SocketInfo(object): self.sock.close() except: pass + + def set_wire_version_range(self, min_wire_version, max_wire_version): + self._min_wire_version = min_wire_version + self._max_wire_version = max_wire_version + + @property + def min_wire_version(self): + assert self._min_wire_version is not None + return self._min_wire_version + + @property + def max_wire_version(self): + assert self._max_wire_version is not None + return self._max_wire_version def __eq__(self, other): # Need to check if other is NO_REQUEST or NO_SOCKET_YET, and then check diff --git a/test/test_auth.py b/test/test_auth.py index 0ccfd0766..a566a0a29 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -583,7 +583,7 @@ class TestClientAuth(unittest.TestCase): # Simulate an authenticate() call on a different socket. credentials = auth._build_credentials_tuple( - 'MONGODB-CR', 'admin', + 'DEFAULT', 'admin', unicode('admin'), unicode('password'), {}) @@ -996,7 +996,7 @@ class TestReplicaSetClientAuth(TestReplicaSetClientBase, TestRequestMixin): # Simulate an authenticate() call on a different socket. credentials = auth._build_credentials_tuple( - 'MONGODB-CR', 'admin', + 'DEFAULT', 'admin', unicode('admin'), unicode('password'), {}) diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index c70b89837..381124446 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -131,6 +131,10 @@ class TestURI(unittest.TestCase): split_options('authMechanism=GSSAPI')) self.assertEqual({'authmechanism': 'MONGODB-CR'}, split_options('authMechanism=MONGODB-CR')) + self.assertEqual({'authmechanism': 'SCRAM-SHA-1'}, + split_options('authMechanism=SCRAM-SHA-1')) + self.assertRaises(ConfigurationError, + split_options, 'authMechanism=foo') self.assertEqual({'authsource': 'foobar'}, split_options('authSource=foobar')) # maxPoolSize isn't yet a documented URI option. self.assertRaises(ConfigurationError, split_options, 'maxpoolsize=50')