From 4c727fd9c0815583c2f827ec4393ba66cf40129a Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 23 Apr 2020 10:11:26 -0700 Subject: [PATCH] PYTHON-2158 Support mechanism negotiation on the connection handshake --- pymongo/auth.py | 16 ++++++---- pymongo/ismaster.py | 12 +++++++ pymongo/mongo_client.py | 2 +- pymongo/pool.py | 70 ++++++++++++++++++++++++----------------- pymongo/topology.py | 4 +-- test/test_client.py | 3 +- test/utils.py | 2 +- 7 files changed, 69 insertions(+), 40 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index b52c6b0af..f37a0b4e5 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -562,12 +562,16 @@ def _authenticate_mongo_cr(credentials, sock_info): def _authenticate_default(credentials, sock_info): if sock_info.max_wire_version >= 7: - source = credentials.source - cmd = SON([ - ('ismaster', 1), - ('saslSupportedMechs', source + '.' + credentials.username)]) - mechs = sock_info.command( - source, cmd, publish_events=False).get('saslSupportedMechs', []) + if credentials in sock_info.negotiated_mechanisms: + mechs = sock_info.negotiated_mechanisms[credentials] + else: + source = credentials.source + cmd = SON([ + ('ismaster', 1), + ('saslSupportedMechs', source + '.' + credentials.username)]) + mechs = sock_info.command( + source, cmd, publish_events=False).get( + 'saslSupportedMechs', []) if 'SCRAM-SHA-256' in mechs: return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256') else: diff --git a/pymongo/ismaster.py b/pymongo/ismaster.py index e723ff0a9..e2afec307 100644 --- a/pymongo/ismaster.py +++ b/pymongo/ismaster.py @@ -156,3 +156,15 @@ class IsMaster(object): @property def compressors(self): return self._doc.get('compression') + + @property + def sasl_supported_mechs(self): + """Supported authentication mechanisms for the current user. + + For example:: + + >>> ismaster.sasl_supported_mechs + ["SCRAM-SHA-1", "SCRAM-SHA-256"] + + """ + return self._doc.get('saslSupportedMechs', []) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ac00a38e8..694178805 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1751,7 +1751,7 @@ class MongoClient(common.BaseObject): maintain connection pool parameters.""" self._process_kill_cursors() try: - self._topology.update_pool() + self._topology.update_pool(self.__all_credentials) except Exception: helpers._handle_exception() diff --git a/pymongo/pool.py b/pymongo/pool.py index 972a8c1db..b255af220 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -458,6 +458,16 @@ class PoolOptions(object): return self.__metadata.copy() +def _negotiate_creds(all_credentials): + """Return one credential that needs mechanism negotiation, if any. + """ + if all_credentials: + for creds in all_credentials.values(): + if creds.mechanism == 'DEFAULT' and creds.username: + return creds + return None + + class SocketInfo(object): """Store a socket with some metadata. @@ -488,13 +498,16 @@ class SocketInfo(object): self.compression_settings = pool.opts.compression_settings self.compression_context = None self.socket_checker = SocketChecker() + # Support for mechanism negotiation on the initial handshake. + # Maps credential to saslSupportedMechs. + self.negotiated_mechanisms = {} # The pool's generation changes with each reset() so we can close # sockets created before the last reset. self.generation = pool.generation self.ready = False - def ismaster(self, metadata, cluster_time): + def ismaster(self, metadata, cluster_time, all_credentials=None): cmd = SON([('ismaster', 1)]) if not self.performed_handshake: cmd['client'] = metadata @@ -504,6 +517,12 @@ class SocketInfo(object): if self.max_wire_version >= 6 and cluster_time is not None: cmd['$clusterTime'] = cluster_time + # XXX: Simplify in PyMongo 4.0 when all_credentials is always a single + # unchangeable value per MongoClient. + creds = _negotiate_creds(all_credentials) + if creds: + cmd['saslSupportedMechs'] = creds.source + '.' + creds.username + ismaster = IsMaster(self.command('admin', cmd, publish_events=False)) self.is_writable = ismaster.is_writable self.max_wire_version = ismaster.max_wire_version @@ -520,6 +539,8 @@ class SocketInfo(object): self.performed_handshake = True self.op_msg_enabled = ismaster.max_wire_version >= 6 + if creds: + self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs return ismaster def command(self, dbname, spec, slave_ok=False, @@ -701,8 +722,7 @@ class SocketInfo(object): self.authset.discard(credentials) for credentials in cached - authset: - auth.authenticate(credentials, self) - self.authset.add(credentials) + self.authenticate(credentials) # CMAP spec says to publish the ready event only after authenticating # the connection. @@ -721,6 +741,8 @@ class SocketInfo(object): """ auth.authenticate(credentials, self) self.authset.add(credentials) + # negotiated_mechanisms are no longer needed. + self.negotiated_mechanisms.pop(credentials, None) def validate_session(self, client, session): """Validate this session before use with client. @@ -1026,7 +1048,7 @@ class Pool: def close(self): self._reset(close=True) - def remove_stale_sockets(self, reference_generation): + def remove_stale_sockets(self, reference_generation, all_credentials): """Removes stale sockets then adds new ones if pool is too small and has not been reset. The `reference_generation` argument specifies the `generation` at the point in time this operation was requested on the @@ -1050,7 +1072,7 @@ class Pool: if not self._socket_semaphore.acquire(False): break try: - sock_info = self.connect() + sock_info = self.connect(all_credentials) with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. @@ -1061,7 +1083,7 @@ class Pool: finally: self._socket_semaphore.release() - def connect(self): + def connect(self, all_credentials=None): """Connect to Mongo and return a new SocketInfo. Can raise ConnectionFailure or CertificateError. @@ -1081,9 +1103,6 @@ class Pool: try: sock = _configured_socket(self.address, self.opts) except socket.error as error: - if sock is not None: - sock.close() - if self.enabled_for_cmap: listeners.publish_connection_closed( self.address, conn_id, ConnectionClosedReason.ERROR) @@ -1092,7 +1111,7 @@ class Pool: sock_info = SocketInfo(sock, self, self.address, conn_id) if self.handshake: - sock_info.ismaster(self.opts.metadata, None) + sock_info.ismaster(self.opts.metadata, None, all_credentials) self.is_writable = sock_info.is_writable return sock_info @@ -1123,29 +1142,23 @@ class Pool: listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_check_out_started(self.address) - # First get a socket, then attempt authentication. Simplifies - # semaphore management in the face of network errors during auth. - sock_info = self._get_socket_no_auth() - checked_auth = False + + sock_info = self._get_socket(all_credentials) + + if self.enabled_for_cmap: + listeners.publish_connection_checked_out( + self.address, sock_info.id) try: - sock_info.check_auth(all_credentials) - checked_auth = True - if self.enabled_for_cmap: - listeners.publish_connection_checked_out( - self.address, sock_info.id) yield sock_info except: # Exception in caller. Decrement semaphore. - self.return_socket(sock_info, publish_checkin=checked_auth) - if self.enabled_for_cmap and not checked_auth: - self.opts.event_listeners.publish_connection_check_out_failed( - self.address, ConnectionCheckOutFailedReason.CONN_ERROR) + self.return_socket(sock_info) raise else: if not checkout: self.return_socket(sock_info) - def _get_socket_no_auth(self): + def _get_socket(self, all_credentials): """Get or create a SocketInfo. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of @@ -1177,10 +1190,11 @@ class Pool: sock_info = self.sockets.popleft() except IndexError: # Can raise ConnectionFailure or CertificateError. - sock_info = self.connect() + sock_info = self.connect(all_credentials) else: if self._perished(sock_info): sock_info = None + sock_info.check_auth(all_credentials) except Exception: self._socket_semaphore.release() with self.lock: @@ -1193,16 +1207,14 @@ class Pool: return sock_info - def return_socket(self, sock_info, publish_checkin=True): + def return_socket(self, sock_info): """Return the socket to the pool, or if it's closed discard it. :Parameters: - `sock_info`: The socket to check into the pool. - - `publish_checkin`: If False, a ConnectionCheckedInEvent will not - be published. """ listeners = self.opts.event_listeners - if self.enabled_for_cmap and publish_checkin: + if self.enabled_for_cmap: listeners.publish_connection_checked_in(self.address, sock_info.id) if self.pid != os.getpid(): self.reset() diff --git a/pymongo/topology.py b/pymongo/topology.py index de446bdef..62ce0cbc2 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -430,7 +430,7 @@ class Topology(object): self._reset_server(address, reset_pool=False, error=error) self._request_check(address) - def update_pool(self): + def update_pool(self, all_credentials): # Remove any stale sockets and add new sockets if pool is too small. servers = [] with self._lock: @@ -438,7 +438,7 @@ class Topology(object): servers.append((server, server._pool.generation)) for server, generation in servers: - server._pool.remove_stale_sockets(generation) + server._pool.remove_stale_sockets(generation, all_credentials) def close(self): """Clear pools and terminate monitors. Topology reopens on demand.""" diff --git a/test/test_client.py b/test/test_client.py index 3ce5b9e41..0540786c5 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1487,7 +1487,8 @@ class TestClient(IntegrationTest): try: while True: for _ in range(10): - client._topology.update_pool() + client._topology.update_pool( + client._MongoClient__all_credentials) if generation != pool.generation: break finally: diff --git a/test/utils.py b/test/utils.py index bd5cea79e..7b5bb7fa5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -230,7 +230,7 @@ class MockPool(object): def update_is_writable(self, is_writable): pass - def remove_stale_sockets(self, reference_generation): + def remove_stale_sockets(self, *args, **kwargs): pass