PYTHON-2158 Support mechanism negotiation on the connection handshake
This commit is contained in:
parent
71d1227932
commit
4c727fd9c0
@ -562,12 +562,16 @@ def _authenticate_mongo_cr(credentials, sock_info):
|
||||
|
||||
def _authenticate_default(credentials, sock_info):
|
||||
if sock_info.max_wire_version >= 7:
|
||||
source = credentials.source
|
||||
cmd = SON([
|
||||
('ismaster', 1),
|
||||
('saslSupportedMechs', source + '.' + credentials.username)])
|
||||
mechs = sock_info.command(
|
||||
source, cmd, publish_events=False).get('saslSupportedMechs', [])
|
||||
if credentials in sock_info.negotiated_mechanisms:
|
||||
mechs = sock_info.negotiated_mechanisms[credentials]
|
||||
else:
|
||||
source = credentials.source
|
||||
cmd = SON([
|
||||
('ismaster', 1),
|
||||
('saslSupportedMechs', source + '.' + credentials.username)])
|
||||
mechs = sock_info.command(
|
||||
source, cmd, publish_events=False).get(
|
||||
'saslSupportedMechs', [])
|
||||
if 'SCRAM-SHA-256' in mechs:
|
||||
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256')
|
||||
else:
|
||||
|
||||
@ -156,3 +156,15 @@ class IsMaster(object):
|
||||
@property
|
||||
def compressors(self):
|
||||
return self._doc.get('compression')
|
||||
|
||||
@property
|
||||
def sasl_supported_mechs(self):
|
||||
"""Supported authentication mechanisms for the current user.
|
||||
|
||||
For example::
|
||||
|
||||
>>> ismaster.sasl_supported_mechs
|
||||
["SCRAM-SHA-1", "SCRAM-SHA-256"]
|
||||
|
||||
"""
|
||||
return self._doc.get('saslSupportedMechs', [])
|
||||
|
||||
@ -1751,7 +1751,7 @@ class MongoClient(common.BaseObject):
|
||||
maintain connection pool parameters."""
|
||||
self._process_kill_cursors()
|
||||
try:
|
||||
self._topology.update_pool()
|
||||
self._topology.update_pool(self.__all_credentials)
|
||||
except Exception:
|
||||
helpers._handle_exception()
|
||||
|
||||
|
||||
@ -458,6 +458,16 @@ class PoolOptions(object):
|
||||
return self.__metadata.copy()
|
||||
|
||||
|
||||
def _negotiate_creds(all_credentials):
|
||||
"""Return one credential that needs mechanism negotiation, if any.
|
||||
"""
|
||||
if all_credentials:
|
||||
for creds in all_credentials.values():
|
||||
if creds.mechanism == 'DEFAULT' and creds.username:
|
||||
return creds
|
||||
return None
|
||||
|
||||
|
||||
class SocketInfo(object):
|
||||
"""Store a socket with some metadata.
|
||||
|
||||
@ -488,13 +498,16 @@ class SocketInfo(object):
|
||||
self.compression_settings = pool.opts.compression_settings
|
||||
self.compression_context = None
|
||||
self.socket_checker = SocketChecker()
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
# Maps credential to saslSupportedMechs.
|
||||
self.negotiated_mechanisms = {}
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.generation = pool.generation
|
||||
self.ready = False
|
||||
|
||||
def ismaster(self, metadata, cluster_time):
|
||||
def ismaster(self, metadata, cluster_time, all_credentials=None):
|
||||
cmd = SON([('ismaster', 1)])
|
||||
if not self.performed_handshake:
|
||||
cmd['client'] = metadata
|
||||
@ -504,6 +517,12 @@ class SocketInfo(object):
|
||||
if self.max_wire_version >= 6 and cluster_time is not None:
|
||||
cmd['$clusterTime'] = cluster_time
|
||||
|
||||
# XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
|
||||
# unchangeable value per MongoClient.
|
||||
creds = _negotiate_creds(all_credentials)
|
||||
if creds:
|
||||
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
|
||||
|
||||
ismaster = IsMaster(self.command('admin', cmd, publish_events=False))
|
||||
self.is_writable = ismaster.is_writable
|
||||
self.max_wire_version = ismaster.max_wire_version
|
||||
@ -520,6 +539,8 @@ class SocketInfo(object):
|
||||
|
||||
self.performed_handshake = True
|
||||
self.op_msg_enabled = ismaster.max_wire_version >= 6
|
||||
if creds:
|
||||
self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs
|
||||
return ismaster
|
||||
|
||||
def command(self, dbname, spec, slave_ok=False,
|
||||
@ -701,8 +722,7 @@ class SocketInfo(object):
|
||||
self.authset.discard(credentials)
|
||||
|
||||
for credentials in cached - authset:
|
||||
auth.authenticate(credentials, self)
|
||||
self.authset.add(credentials)
|
||||
self.authenticate(credentials)
|
||||
|
||||
# CMAP spec says to publish the ready event only after authenticating
|
||||
# the connection.
|
||||
@ -721,6 +741,8 @@ class SocketInfo(object):
|
||||
"""
|
||||
auth.authenticate(credentials, self)
|
||||
self.authset.add(credentials)
|
||||
# negotiated_mechanisms are no longer needed.
|
||||
self.negotiated_mechanisms.pop(credentials, None)
|
||||
|
||||
def validate_session(self, client, session):
|
||||
"""Validate this session before use with client.
|
||||
@ -1026,7 +1048,7 @@ class Pool:
|
||||
def close(self):
|
||||
self._reset(close=True)
|
||||
|
||||
def remove_stale_sockets(self, reference_generation):
|
||||
def remove_stale_sockets(self, reference_generation, all_credentials):
|
||||
"""Removes stale sockets then adds new ones if pool is too small and
|
||||
has not been reset. The `reference_generation` argument specifies the
|
||||
`generation` at the point in time this operation was requested on the
|
||||
@ -1050,7 +1072,7 @@ class Pool:
|
||||
if not self._socket_semaphore.acquire(False):
|
||||
break
|
||||
try:
|
||||
sock_info = self.connect()
|
||||
sock_info = self.connect(all_credentials)
|
||||
with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
@ -1061,7 +1083,7 @@ class Pool:
|
||||
finally:
|
||||
self._socket_semaphore.release()
|
||||
|
||||
def connect(self):
|
||||
def connect(self, all_credentials=None):
|
||||
"""Connect to Mongo and return a new SocketInfo.
|
||||
|
||||
Can raise ConnectionFailure or CertificateError.
|
||||
@ -1081,9 +1103,6 @@ class Pool:
|
||||
try:
|
||||
sock = _configured_socket(self.address, self.opts)
|
||||
except socket.error as error:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_closed(
|
||||
self.address, conn_id, ConnectionClosedReason.ERROR)
|
||||
@ -1092,7 +1111,7 @@ class Pool:
|
||||
|
||||
sock_info = SocketInfo(sock, self, self.address, conn_id)
|
||||
if self.handshake:
|
||||
sock_info.ismaster(self.opts.metadata, None)
|
||||
sock_info.ismaster(self.opts.metadata, None, all_credentials)
|
||||
self.is_writable = sock_info.is_writable
|
||||
|
||||
return sock_info
|
||||
@ -1123,29 +1142,23 @@ class Pool:
|
||||
listeners = self.opts.event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_check_out_started(self.address)
|
||||
# First get a socket, then attempt authentication. Simplifies
|
||||
# semaphore management in the face of network errors during auth.
|
||||
sock_info = self._get_socket_no_auth()
|
||||
checked_auth = False
|
||||
|
||||
sock_info = self._get_socket(all_credentials)
|
||||
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_out(
|
||||
self.address, sock_info.id)
|
||||
try:
|
||||
sock_info.check_auth(all_credentials)
|
||||
checked_auth = True
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_out(
|
||||
self.address, sock_info.id)
|
||||
yield sock_info
|
||||
except:
|
||||
# Exception in caller. Decrement semaphore.
|
||||
self.return_socket(sock_info, publish_checkin=checked_auth)
|
||||
if self.enabled_for_cmap and not checked_auth:
|
||||
self.opts.event_listeners.publish_connection_check_out_failed(
|
||||
self.address, ConnectionCheckOutFailedReason.CONN_ERROR)
|
||||
self.return_socket(sock_info)
|
||||
raise
|
||||
else:
|
||||
if not checkout:
|
||||
self.return_socket(sock_info)
|
||||
|
||||
def _get_socket_no_auth(self):
|
||||
def _get_socket(self, all_credentials):
|
||||
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
|
||||
# We use the pid here to avoid issues with fork / multiprocessing.
|
||||
# See test.test_client:TestClient.test_fork for an example of
|
||||
@ -1177,10 +1190,11 @@ class Pool:
|
||||
sock_info = self.sockets.popleft()
|
||||
except IndexError:
|
||||
# Can raise ConnectionFailure or CertificateError.
|
||||
sock_info = self.connect()
|
||||
sock_info = self.connect(all_credentials)
|
||||
else:
|
||||
if self._perished(sock_info):
|
||||
sock_info = None
|
||||
sock_info.check_auth(all_credentials)
|
||||
except Exception:
|
||||
self._socket_semaphore.release()
|
||||
with self.lock:
|
||||
@ -1193,16 +1207,14 @@ class Pool:
|
||||
|
||||
return sock_info
|
||||
|
||||
def return_socket(self, sock_info, publish_checkin=True):
|
||||
def return_socket(self, sock_info):
|
||||
"""Return the socket to the pool, or if it's closed discard it.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info`: The socket to check into the pool.
|
||||
- `publish_checkin`: If False, a ConnectionCheckedInEvent will not
|
||||
be published.
|
||||
"""
|
||||
listeners = self.opts.event_listeners
|
||||
if self.enabled_for_cmap and publish_checkin:
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_in(self.address, sock_info.id)
|
||||
if self.pid != os.getpid():
|
||||
self.reset()
|
||||
|
||||
@ -430,7 +430,7 @@ class Topology(object):
|
||||
self._reset_server(address, reset_pool=False, error=error)
|
||||
self._request_check(address)
|
||||
|
||||
def update_pool(self):
|
||||
def update_pool(self, all_credentials):
|
||||
# Remove any stale sockets and add new sockets if pool is too small.
|
||||
servers = []
|
||||
with self._lock:
|
||||
@ -438,7 +438,7 @@ class Topology(object):
|
||||
servers.append((server, server._pool.generation))
|
||||
|
||||
for server, generation in servers:
|
||||
server._pool.remove_stale_sockets(generation)
|
||||
server._pool.remove_stale_sockets(generation, all_credentials)
|
||||
|
||||
def close(self):
|
||||
"""Clear pools and terminate monitors. Topology reopens on demand."""
|
||||
|
||||
@ -1487,7 +1487,8 @@ class TestClient(IntegrationTest):
|
||||
try:
|
||||
while True:
|
||||
for _ in range(10):
|
||||
client._topology.update_pool()
|
||||
client._topology.update_pool(
|
||||
client._MongoClient__all_credentials)
|
||||
if generation != pool.generation:
|
||||
break
|
||||
finally:
|
||||
|
||||
@ -230,7 +230,7 @@ class MockPool(object):
|
||||
def update_is_writable(self, is_writable):
|
||||
pass
|
||||
|
||||
def remove_stale_sockets(self, reference_generation):
|
||||
def remove_stale_sockets(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user