PYTHON-2158 Support mechanism negotiation on the connection handshake

This commit is contained in:
Shane Harvey 2020-04-23 10:11:26 -07:00
parent 71d1227932
commit 4c727fd9c0
7 changed files with 69 additions and 40 deletions

View File

@ -562,12 +562,16 @@ def _authenticate_mongo_cr(credentials, sock_info):
def _authenticate_default(credentials, sock_info):
if sock_info.max_wire_version >= 7:
source = credentials.source
cmd = SON([
('ismaster', 1),
('saslSupportedMechs', source + '.' + credentials.username)])
mechs = sock_info.command(
source, cmd, publish_events=False).get('saslSupportedMechs', [])
if credentials in sock_info.negotiated_mechanisms:
mechs = sock_info.negotiated_mechanisms[credentials]
else:
source = credentials.source
cmd = SON([
('ismaster', 1),
('saslSupportedMechs', source + '.' + credentials.username)])
mechs = sock_info.command(
source, cmd, publish_events=False).get(
'saslSupportedMechs', [])
if 'SCRAM-SHA-256' in mechs:
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256')
else:

View File

@ -156,3 +156,15 @@ class IsMaster(object):
@property
def compressors(self):
return self._doc.get('compression')
@property
def sasl_supported_mechs(self):
"""Supported authentication mechanisms for the current user.
For example::
>>> ismaster.sasl_supported_mechs
["SCRAM-SHA-1", "SCRAM-SHA-256"]
"""
return self._doc.get('saslSupportedMechs', [])

View File

@ -1751,7 +1751,7 @@ class MongoClient(common.BaseObject):
maintain connection pool parameters."""
self._process_kill_cursors()
try:
self._topology.update_pool()
self._topology.update_pool(self.__all_credentials)
except Exception:
helpers._handle_exception()

View File

@ -458,6 +458,16 @@ class PoolOptions(object):
return self.__metadata.copy()
def _negotiate_creds(all_credentials):
"""Return one credential that needs mechanism negotiation, if any.
"""
if all_credentials:
for creds in all_credentials.values():
if creds.mechanism == 'DEFAULT' and creds.username:
return creds
return None
class SocketInfo(object):
"""Store a socket with some metadata.
@ -488,13 +498,16 @@ class SocketInfo(object):
self.compression_settings = pool.opts.compression_settings
self.compression_context = None
self.socket_checker = SocketChecker()
# Support for mechanism negotiation on the initial handshake.
# Maps credential to saslSupportedMechs.
self.negotiated_mechanisms = {}
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.generation = pool.generation
self.ready = False
def ismaster(self, metadata, cluster_time):
def ismaster(self, metadata, cluster_time, all_credentials=None):
cmd = SON([('ismaster', 1)])
if not self.performed_handshake:
cmd['client'] = metadata
@ -504,6 +517,12 @@ class SocketInfo(object):
if self.max_wire_version >= 6 and cluster_time is not None:
cmd['$clusterTime'] = cluster_time
# XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
# unchangeable value per MongoClient.
creds = _negotiate_creds(all_credentials)
if creds:
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
ismaster = IsMaster(self.command('admin', cmd, publish_events=False))
self.is_writable = ismaster.is_writable
self.max_wire_version = ismaster.max_wire_version
@ -520,6 +539,8 @@ class SocketInfo(object):
self.performed_handshake = True
self.op_msg_enabled = ismaster.max_wire_version >= 6
if creds:
self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs
return ismaster
def command(self, dbname, spec, slave_ok=False,
@ -701,8 +722,7 @@ class SocketInfo(object):
self.authset.discard(credentials)
for credentials in cached - authset:
auth.authenticate(credentials, self)
self.authset.add(credentials)
self.authenticate(credentials)
# CMAP spec says to publish the ready event only after authenticating
# the connection.
@ -721,6 +741,8 @@ class SocketInfo(object):
"""
auth.authenticate(credentials, self)
self.authset.add(credentials)
# negotiated_mechanisms are no longer needed.
self.negotiated_mechanisms.pop(credentials, None)
def validate_session(self, client, session):
"""Validate this session before use with client.
@ -1026,7 +1048,7 @@ class Pool:
def close(self):
self._reset(close=True)
def remove_stale_sockets(self, reference_generation):
def remove_stale_sockets(self, reference_generation, all_credentials):
"""Removes stale sockets then adds new ones if pool is too small and
has not been reset. The `reference_generation` argument specifies the
`generation` at the point in time this operation was requested on the
@ -1050,7 +1072,7 @@ class Pool:
if not self._socket_semaphore.acquire(False):
break
try:
sock_info = self.connect()
sock_info = self.connect(all_credentials)
with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
@ -1061,7 +1083,7 @@ class Pool:
finally:
self._socket_semaphore.release()
def connect(self):
def connect(self, all_credentials=None):
"""Connect to Mongo and return a new SocketInfo.
Can raise ConnectionFailure or CertificateError.
@ -1081,9 +1103,6 @@ class Pool:
try:
sock = _configured_socket(self.address, self.opts)
except socket.error as error:
if sock is not None:
sock.close()
if self.enabled_for_cmap:
listeners.publish_connection_closed(
self.address, conn_id, ConnectionClosedReason.ERROR)
@ -1092,7 +1111,7 @@ class Pool:
sock_info = SocketInfo(sock, self, self.address, conn_id)
if self.handshake:
sock_info.ismaster(self.opts.metadata, None)
sock_info.ismaster(self.opts.metadata, None, all_credentials)
self.is_writable = sock_info.is_writable
return sock_info
@ -1123,29 +1142,23 @@ class Pool:
listeners = self.opts.event_listeners
if self.enabled_for_cmap:
listeners.publish_connection_check_out_started(self.address)
# First get a socket, then attempt authentication. Simplifies
# semaphore management in the face of network errors during auth.
sock_info = self._get_socket_no_auth()
checked_auth = False
sock_info = self._get_socket(all_credentials)
if self.enabled_for_cmap:
listeners.publish_connection_checked_out(
self.address, sock_info.id)
try:
sock_info.check_auth(all_credentials)
checked_auth = True
if self.enabled_for_cmap:
listeners.publish_connection_checked_out(
self.address, sock_info.id)
yield sock_info
except:
# Exception in caller. Decrement semaphore.
self.return_socket(sock_info, publish_checkin=checked_auth)
if self.enabled_for_cmap and not checked_auth:
self.opts.event_listeners.publish_connection_check_out_failed(
self.address, ConnectionCheckOutFailedReason.CONN_ERROR)
self.return_socket(sock_info)
raise
else:
if not checkout:
self.return_socket(sock_info)
def _get_socket_no_auth(self):
def _get_socket(self, all_credentials):
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of
@ -1177,10 +1190,11 @@ class Pool:
sock_info = self.sockets.popleft()
except IndexError:
# Can raise ConnectionFailure or CertificateError.
sock_info = self.connect()
sock_info = self.connect(all_credentials)
else:
if self._perished(sock_info):
sock_info = None
sock_info.check_auth(all_credentials)
except Exception:
self._socket_semaphore.release()
with self.lock:
@ -1193,16 +1207,14 @@ class Pool:
return sock_info
def return_socket(self, sock_info, publish_checkin=True):
def return_socket(self, sock_info):
"""Return the socket to the pool, or if it's closed discard it.
:Parameters:
- `sock_info`: The socket to check into the pool.
- `publish_checkin`: If False, a ConnectionCheckedInEvent will not
be published.
"""
listeners = self.opts.event_listeners
if self.enabled_for_cmap and publish_checkin:
if self.enabled_for_cmap:
listeners.publish_connection_checked_in(self.address, sock_info.id)
if self.pid != os.getpid():
self.reset()

View File

@ -430,7 +430,7 @@ class Topology(object):
self._reset_server(address, reset_pool=False, error=error)
self._request_check(address)
def update_pool(self):
def update_pool(self, all_credentials):
# Remove any stale sockets and add new sockets if pool is too small.
servers = []
with self._lock:
@ -438,7 +438,7 @@ class Topology(object):
servers.append((server, server._pool.generation))
for server, generation in servers:
server._pool.remove_stale_sockets(generation)
server._pool.remove_stale_sockets(generation, all_credentials)
def close(self):
"""Clear pools and terminate monitors. Topology reopens on demand."""

View File

@ -1487,7 +1487,8 @@ class TestClient(IntegrationTest):
try:
while True:
for _ in range(10):
client._topology.update_pool()
client._topology.update_pool(
client._MongoClient__all_credentials)
if generation != pool.generation:
break
finally:

View File

@ -230,7 +230,7 @@ class MockPool(object):
def update_is_writable(self, is_writable):
pass
def remove_stale_sockets(self, reference_generation):
def remove_stale_sockets(self, *args, **kwargs):
pass