PYTHON-2585 Remove legacy multi-auth code (#816)
This commit is contained in:
parent
7bd9bd7b47
commit
c94a3ad1df
@ -195,7 +195,7 @@ def _authenticate_scram(credentials, sock_info, mechanism):
|
||||
# Make local
|
||||
_hmac = hmac.HMAC
|
||||
|
||||
ctx = sock_info.auth_ctx.get(credentials)
|
||||
ctx = sock_info.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
nonce, first_bare = ctx.scram_data
|
||||
res = ctx.speculative_authenticate
|
||||
@ -424,7 +424,7 @@ def _authenticate_plain(credentials, sock_info):
|
||||
def _authenticate_x509(credentials, sock_info):
|
||||
"""Authenticate using MONGODB-X509.
|
||||
"""
|
||||
ctx = sock_info.auth_ctx.get(credentials)
|
||||
ctx = sock_info.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
# MONGODB-X509 is done after the speculative auth step.
|
||||
return
|
||||
@ -454,8 +454,8 @@ def _authenticate_mongo_cr(credentials, sock_info):
|
||||
|
||||
def _authenticate_default(credentials, sock_info):
|
||||
if sock_info.max_wire_version >= 7:
|
||||
if credentials in sock_info.negotiated_mechanisms:
|
||||
mechs = sock_info.negotiated_mechanisms[credentials]
|
||||
if sock_info.negotiated_mechs:
|
||||
mechs = sock_info.negotiated_mechs
|
||||
else:
|
||||
source = credentials.source
|
||||
cmd = sock_info.hello_cmd()
|
||||
|
||||
@ -117,8 +117,9 @@ def _parse_ssl_options(options):
|
||||
return None, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(options):
|
||||
def _parse_pool_options(username, password, database, options):
|
||||
"""Parse connection pool options."""
|
||||
credentials = _parse_credentials(username, password, database, options)
|
||||
max_pool_size = options.get('maxpoolsize', common.MAX_POOL_SIZE)
|
||||
min_pool_size = options.get('minpoolsize', common.MIN_POOL_SIZE)
|
||||
max_idle_time_seconds = options.get(
|
||||
@ -151,7 +152,8 @@ def _parse_pool_options(options):
|
||||
compression_settings,
|
||||
max_connecting=max_connecting,
|
||||
server_api=server_api,
|
||||
load_balanced=load_balanced)
|
||||
load_balanced=load_balanced,
|
||||
credentials=credentials)
|
||||
|
||||
|
||||
class ClientOptions(object):
|
||||
@ -164,10 +166,7 @@ class ClientOptions(object):
|
||||
|
||||
def __init__(self, username, password, database, options):
|
||||
self.__options = options
|
||||
|
||||
self.__codec_options = _parse_codec_options(options)
|
||||
self.__credentials = _parse_credentials(
|
||||
username, password, database, options)
|
||||
self.__direct_connection = options.get('directconnection')
|
||||
self.__local_threshold_ms = options.get(
|
||||
'localthresholdms', common.LOCAL_THRESHOLD_MS)
|
||||
@ -175,7 +174,8 @@ class ClientOptions(object):
|
||||
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
|
||||
self.__server_selection_timeout = options.get(
|
||||
'serverselectiontimeoutms', common.SERVER_SELECTION_TIMEOUT)
|
||||
self.__pool_options = _parse_pool_options(options)
|
||||
self.__pool_options = _parse_pool_options(
|
||||
username, password, database, options)
|
||||
self.__read_preference = _parse_read_preference(options)
|
||||
self.__replica_set_name = options.get('replicaset')
|
||||
self.__write_concern = _parse_write_concern(options)
|
||||
@ -205,11 +205,6 @@ class ClientOptions(object):
|
||||
"""A :class:`~bson.codec_options.CodecOptions` instance."""
|
||||
return self.__codec_options
|
||||
|
||||
@property
|
||||
def _credentials(self):
|
||||
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
|
||||
return self.__credentials
|
||||
|
||||
@property
|
||||
def direct_connection(self):
|
||||
"""Whether to connect to the deployment in 'Single' topology."""
|
||||
|
||||
@ -729,11 +729,6 @@ class MongoClient(common.BaseObject):
|
||||
options.write_concern,
|
||||
options.read_concern)
|
||||
|
||||
self.__all_credentials = {}
|
||||
creds = options._credentials
|
||||
if creds:
|
||||
self.__all_credentials[creds.source] = creds
|
||||
|
||||
self._topology_settings = TopologySettings(
|
||||
seeds=seeds,
|
||||
replica_set_name=options.replica_set_name,
|
||||
@ -1090,8 +1085,7 @@ class MongoClient(common.BaseObject):
|
||||
if in_txn and session._pinned_connection:
|
||||
yield session._pinned_connection
|
||||
return
|
||||
with server.get_socket(
|
||||
self.__all_credentials, handler=err_handler) as sock_info:
|
||||
with server.get_socket(handler=err_handler) as sock_info:
|
||||
# Pin this session to the selected server or connection.
|
||||
if (in_txn and server.description.server_type in (
|
||||
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)):
|
||||
@ -1535,7 +1529,7 @@ class MongoClient(common.BaseObject):
|
||||
maintain connection pool parameters."""
|
||||
try:
|
||||
self._process_kill_cursors()
|
||||
self._topology.update_pool(self.__all_credentials)
|
||||
self._topology.update_pool()
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -246,7 +246,7 @@ class Monitor(MonitorBase):
|
||||
|
||||
if self._cancel_context and self._cancel_context.cancelled:
|
||||
self._reset_connection()
|
||||
with self._pool.get_socket({}) as sock_info:
|
||||
with self._pool.get_socket() as sock_info:
|
||||
self._cancel_context = sock_info.cancel_context
|
||||
response, round_trip_time = self._check_with_socket(sock_info)
|
||||
if not response.awaitable:
|
||||
@ -275,11 +275,10 @@ class Monitor(MonitorBase):
|
||||
response = conn._hello(
|
||||
cluster_time,
|
||||
self._server_description.topology_version,
|
||||
self._settings.heartbeat_frequency,
|
||||
None)
|
||||
self._settings.heartbeat_frequency)
|
||||
else:
|
||||
# New connection handshake or polling hello (MongoDB <4.4).
|
||||
response = conn._hello(cluster_time, None, None, None)
|
||||
response = conn._hello(cluster_time, None, None)
|
||||
return response, time.monotonic() - start
|
||||
|
||||
|
||||
@ -388,7 +387,7 @@ class _RttMonitor(MonitorBase):
|
||||
|
||||
def _ping(self):
|
||||
"""Run a "hello" command and return the RTT."""
|
||||
with self._pool.get_socket({}) as sock_info:
|
||||
with self._pool.get_socket() as sock_info:
|
||||
if self._executor._stopped:
|
||||
raise Exception('_RttMonitor closed')
|
||||
start = time.monotonic()
|
||||
|
||||
122
pymongo/pool.py
122
pymongo/pool.py
@ -275,7 +275,8 @@ class PoolOptions(object):
|
||||
'__ssl_context', '__tls_allow_invalid_hostnames',
|
||||
'__event_listeners', '__appname', '__driver', '__metadata',
|
||||
'__compression_settings', '__max_connecting',
|
||||
'__pause_enabled', '__server_api', '__load_balanced')
|
||||
'__pause_enabled', '__server_api', '__load_balanced',
|
||||
'__credentials')
|
||||
|
||||
def __init__(self, max_pool_size=MAX_POOL_SIZE,
|
||||
min_pool_size=MIN_POOL_SIZE,
|
||||
@ -285,7 +286,8 @@ class PoolOptions(object):
|
||||
tls_allow_invalid_hostnames=False,
|
||||
event_listeners=None, appname=None, driver=None,
|
||||
compression_settings=None, max_connecting=MAX_CONNECTING,
|
||||
pause_enabled=True, server_api=None, load_balanced=None):
|
||||
pause_enabled=True, server_api=None, load_balanced=None,
|
||||
credentials=None):
|
||||
self.__max_pool_size = max_pool_size
|
||||
self.__min_pool_size = min_pool_size
|
||||
self.__max_idle_time_seconds = max_idle_time_seconds
|
||||
@ -302,6 +304,7 @@ class PoolOptions(object):
|
||||
self.__pause_enabled = pause_enabled
|
||||
self.__server_api = server_api
|
||||
self.__load_balanced = load_balanced
|
||||
self.__credentials = credentials
|
||||
self.__metadata = copy.deepcopy(_METADATA)
|
||||
if appname:
|
||||
self.__metadata['application'] = {'name': appname}
|
||||
@ -325,6 +328,11 @@ class PoolOptions(object):
|
||||
self.__metadata['platform'] = "%s|%s" % (
|
||||
_METADATA['platform'], driver.platform)
|
||||
|
||||
@property
|
||||
def _credentials(self):
|
||||
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
|
||||
return self.__credentials
|
||||
|
||||
@property
|
||||
def non_default_options(self):
|
||||
"""The non-default options this pool was created with.
|
||||
@ -457,25 +465,6 @@ class PoolOptions(object):
|
||||
return self.__load_balanced
|
||||
|
||||
|
||||
def _negotiate_creds(all_credentials):
|
||||
"""Return one credential that needs mechanism negotiation, if any.
|
||||
"""
|
||||
if all_credentials:
|
||||
for creds in all_credentials.values():
|
||||
if creds.mechanism == 'DEFAULT' and creds.username:
|
||||
return creds
|
||||
return None
|
||||
|
||||
|
||||
def _speculative_context(all_credentials):
|
||||
"""Return the _AuthContext to use for speculative auth, if any.
|
||||
"""
|
||||
if all_credentials and len(all_credentials) == 1:
|
||||
creds = next(iter(all_credentials.values()))
|
||||
return auth._AuthContext.from_credentials(creds)
|
||||
return None
|
||||
|
||||
|
||||
class _CancellationContext(object):
|
||||
def __init__(self):
|
||||
self._cancelled = False
|
||||
@ -504,7 +493,7 @@ class SocketInfo(object):
|
||||
self.sock = sock
|
||||
self.address = address
|
||||
self.id = id
|
||||
self.authset = set()
|
||||
self.authed = set()
|
||||
self.closed = False
|
||||
self.last_checkin_time = time.monotonic()
|
||||
self.performed_handshake = False
|
||||
@ -523,9 +512,8 @@ class SocketInfo(object):
|
||||
self.compression_context = None
|
||||
self.socket_checker = SocketChecker()
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
# Maps credential to saslSupportedMechs.
|
||||
self.negotiated_mechanisms = {}
|
||||
self.auth_ctx = {}
|
||||
self.negotiated_mechs = None
|
||||
self.auth_ctx = None
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
@ -567,11 +555,10 @@ class SocketInfo(object):
|
||||
else:
|
||||
return SON([(HelloCompat.LEGACY_CMD, 1), ('helloOk', True)])
|
||||
|
||||
def hello(self, all_credentials=None):
|
||||
return self._hello(None, None, None, all_credentials)
|
||||
def hello(self):
|
||||
return self._hello(None, None, None)
|
||||
|
||||
def _hello(self, cluster_time, topology_version,
|
||||
heartbeat_frequency, all_credentials):
|
||||
def _hello(self, cluster_time, topology_version, heartbeat_frequency):
|
||||
cmd = self.hello_cmd()
|
||||
performing_handshake = not self.performed_handshake
|
||||
awaitable = False
|
||||
@ -594,14 +581,15 @@ class SocketInfo(object):
|
||||
if not performing_handshake and cluster_time is not None:
|
||||
cmd['$clusterTime'] = cluster_time
|
||||
|
||||
# XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
|
||||
# unchangeable value per MongoClient.
|
||||
creds = _negotiate_creds(all_credentials)
|
||||
creds = self.opts._credentials
|
||||
if creds:
|
||||
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
|
||||
auth_ctx = _speculative_context(all_credentials)
|
||||
if auth_ctx:
|
||||
cmd['speculativeAuthenticate'] = auth_ctx.speculate_command()
|
||||
if creds.mechanism == 'DEFAULT' and creds.username:
|
||||
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
|
||||
auth_ctx = auth._AuthContext.from_credentials(creds)
|
||||
if auth_ctx:
|
||||
cmd['speculativeAuthenticate'] = auth_ctx.speculate_command()
|
||||
else:
|
||||
auth_ctx = None
|
||||
|
||||
doc = self.command('admin', cmd, publish_events=False,
|
||||
exhaust_allowed=awaitable)
|
||||
@ -628,11 +616,11 @@ class SocketInfo(object):
|
||||
|
||||
self.op_msg_enabled = True
|
||||
if creds:
|
||||
self.negotiated_mechanisms[creds] = hello.sasl_supported_mechs
|
||||
self.negotiated_mechs = hello.sasl_supported_mechs
|
||||
if auth_ctx:
|
||||
auth_ctx.parse_response(hello)
|
||||
if auth_ctx.speculate_succeeded():
|
||||
self.auth_ctx[auth_ctx.credentials] = auth_ctx
|
||||
self.auth_ctx = auth_ctx
|
||||
if self.opts.load_balanced:
|
||||
if not hello.service_id:
|
||||
raise ConfigurationError(
|
||||
@ -799,41 +787,21 @@ class SocketInfo(object):
|
||||
helpers._check_command_response(result, self.max_wire_version)
|
||||
return result
|
||||
|
||||
def check_auth(self, all_credentials):
|
||||
"""Update this socket's authentication.
|
||||
def authenticate(self):
|
||||
"""Authenticate to the server if needed.
|
||||
|
||||
Log in or out to bring this socket's credentials up to date with
|
||||
those provided. Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
:Parameters:
|
||||
- `all_credentials`: dict, maps auth source to MongoCredential.
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
"""
|
||||
if all_credentials:
|
||||
for credentials in all_credentials.values():
|
||||
if credentials not in self.authset:
|
||||
self.authenticate(credentials)
|
||||
|
||||
# CMAP spec says to publish the ready event only after authenticating
|
||||
# the connection.
|
||||
if not self.ready:
|
||||
creds = self.opts._credentials
|
||||
if creds:
|
||||
auth.authenticate(creds, self)
|
||||
self.ready = True
|
||||
if self.enabled_for_cmap:
|
||||
self.listeners.publish_connection_ready(self.address, self.id)
|
||||
|
||||
def authenticate(self, credentials):
|
||||
"""Log in to the server and store these credentials in `authset`.
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
:Parameters:
|
||||
- `credentials`: A MongoCredential.
|
||||
"""
|
||||
auth.authenticate(credentials, self)
|
||||
self.authset.add(credentials)
|
||||
# negotiated_mechanisms are no longer needed.
|
||||
self.negotiated_mechanisms.pop(credentials, None)
|
||||
self.auth_ctx.pop(credentials, None)
|
||||
|
||||
def validate_session(self, client, session):
|
||||
"""Validate this session before use with client.
|
||||
|
||||
@ -1245,7 +1213,7 @@ class Pool:
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def remove_stale_sockets(self, reference_generation, all_credentials):
|
||||
def remove_stale_sockets(self, reference_generation):
|
||||
"""Removes stale sockets then adds new ones if pool is too small and
|
||||
has not been reset. The `reference_generation` argument specifies the
|
||||
`generation` at the point in time this operation was requested on the
|
||||
@ -1281,7 +1249,7 @@ class Pool:
|
||||
return
|
||||
self._pending += 1
|
||||
incremented = True
|
||||
sock_info = self.connect(all_credentials)
|
||||
sock_info = self.connect()
|
||||
with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
@ -1300,7 +1268,7 @@ class Pool:
|
||||
self.requests -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def connect(self, all_credentials=None):
|
||||
def connect(self):
|
||||
"""Connect to Mongo and return a new SocketInfo.
|
||||
|
||||
Can raise ConnectionFailure.
|
||||
@ -1331,10 +1299,10 @@ class Pool:
|
||||
sock_info = SocketInfo(sock, self, self.address, conn_id)
|
||||
try:
|
||||
if self.handshake:
|
||||
sock_info.hello(all_credentials)
|
||||
sock_info.hello()
|
||||
self.is_writable = sock_info.is_writable
|
||||
|
||||
sock_info.check_auth(all_credentials)
|
||||
sock_info.authenticate()
|
||||
except BaseException:
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
raise
|
||||
@ -1342,7 +1310,7 @@ class Pool:
|
||||
return sock_info
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, all_credentials, handler=None):
|
||||
def get_socket(self, handler=None):
|
||||
"""Get a socket from the pool. Use with a "with" statement.
|
||||
|
||||
Returns a :class:`SocketInfo` object wrapping a connected
|
||||
@ -1350,25 +1318,20 @@ class Pool:
|
||||
|
||||
This method should always be used in a with-statement::
|
||||
|
||||
with pool.get_socket(credentials) as socket_info:
|
||||
with pool.get_socket() as socket_info:
|
||||
socket_info.send_message(msg)
|
||||
data = socket_info.receive_message(op_code, request_id)
|
||||
|
||||
The socket is logged in or out as needed to match ``all_credentials``
|
||||
using the correct authentication mechanism for the server's wire
|
||||
protocol version.
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
:Parameters:
|
||||
- `all_credentials`: dict, maps auth source to MongoCredential.
|
||||
- `handler` (optional): A _MongoClientErrorHandler.
|
||||
"""
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_check_out_started(self.address)
|
||||
|
||||
sock_info = self._get_socket(all_credentials)
|
||||
sock_info = self._get_socket()
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_out(
|
||||
self.address, sock_info.id)
|
||||
@ -1407,7 +1370,7 @@ class Pool:
|
||||
_raise_connection_failure(
|
||||
self.address, AutoReconnect('connection pool paused'))
|
||||
|
||||
def _get_socket(self, all_credentials):
|
||||
def _get_socket(self):
|
||||
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
|
||||
# We use the pid here to avoid issues with fork / multiprocessing.
|
||||
# See test.test_client:TestClient.test_fork for an example of
|
||||
@ -1480,12 +1443,11 @@ class Pool:
|
||||
continue
|
||||
else: # We need to create a new connection
|
||||
try:
|
||||
sock_info = self.connect(all_credentials)
|
||||
sock_info = self.connect()
|
||||
finally:
|
||||
with self._max_connecting_cond:
|
||||
self._pending -= 1
|
||||
self._max_connecting_cond.notify()
|
||||
sock_info.check_auth(all_credentials)
|
||||
except BaseException:
|
||||
if sock_info:
|
||||
# We checked out a socket but authentication failed.
|
||||
|
||||
@ -77,9 +77,9 @@ class Server(object):
|
||||
Can raise ConnectionFailure, OperationFailure, etc.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info` - A SocketInfo instance.
|
||||
- `operation`: A _Query or _GetMore object.
|
||||
- `set_secondary_okay`: Pass to operation.get_message.
|
||||
- `all_credentials`: dict, maps auth source to MongoCredential.
|
||||
- `listeners`: Instance of _EventListeners or None.
|
||||
- `unpack_res`: A callable that decodes the wire protocol response.
|
||||
"""
|
||||
@ -200,8 +200,8 @@ class Server(object):
|
||||
|
||||
return response
|
||||
|
||||
def get_socket(self, all_credentials, handler=None):
|
||||
return self.pool.get_socket(all_credentials, handler)
|
||||
def get_socket(self, handler=None):
|
||||
return self.pool.get_socket(handler)
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
|
||||
@ -444,7 +444,7 @@ class Topology(object):
|
||||
return self._description.known_servers
|
||||
return self._description.readable_servers
|
||||
|
||||
def update_pool(self, all_credentials):
|
||||
def update_pool(self):
|
||||
# Remove any stale sockets and add new sockets if pool is too small.
|
||||
servers = []
|
||||
with self._lock:
|
||||
@ -456,7 +456,7 @@ class Topology(object):
|
||||
|
||||
for server, generation in servers:
|
||||
try:
|
||||
server.pool.remove_stale_sockets(generation, all_credentials)
|
||||
server.pool.remove_stale_sockets(generation)
|
||||
except PyMongoError as exc:
|
||||
ctx = _ErrorContext(exc, 0, generation, False, None)
|
||||
self.handle_error(server.description.address, ctx)
|
||||
|
||||
@ -40,7 +40,7 @@ class MockPool(Pool):
|
||||
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, all_credentials, handler=None):
|
||||
def get_socket(self, handler=None):
|
||||
client = self.client
|
||||
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
|
||||
if host_and_port in client.mock_down_hosts:
|
||||
@ -51,7 +51,7 @@ class MockPool(Pool):
|
||||
+ client.mock_members
|
||||
+ client.mock_mongoses), "bad host: %s" % host_and_port
|
||||
|
||||
with Pool.get_socket(self, all_credentials, handler) as sock_info:
|
||||
with Pool.get_socket(self, handler) as sock_info:
|
||||
sock_info.mock_host = self.mock_host
|
||||
sock_info.mock_port = self.mock_port
|
||||
yield sock_info
|
||||
|
||||
@ -30,6 +30,7 @@ from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.saslprep import HAVE_STRINGPREP
|
||||
from test import client_context, IntegrationTest, SkipTest, unittest, Version
|
||||
from test.utils import (delay,
|
||||
get_pool,
|
||||
ignore_deprecations,
|
||||
single_client,
|
||||
rs_or_single_client,
|
||||
@ -521,10 +522,12 @@ class TestSCRAM(IntegrationTest):
|
||||
|
||||
def test_cache(self):
|
||||
client = single_client()
|
||||
credentials = client.options.pool_options._credentials
|
||||
cache = credentials.cache
|
||||
self.assertIsNotNone(cache)
|
||||
self.assertIsNone(cache.data)
|
||||
# Force authentication.
|
||||
client.admin.command('ping')
|
||||
all_credentials = client._MongoClient__all_credentials
|
||||
credentials = all_credentials.get('admin')
|
||||
cache = credentials.cache
|
||||
self.assertIsNotNone(cache)
|
||||
data = cache.data
|
||||
@ -536,19 +539,6 @@ class TestSCRAM(IntegrationTest):
|
||||
self.assertIsInstance(salt, bytes)
|
||||
self.assertIsInstance(iterations, int)
|
||||
|
||||
pool = next(iter(client._topology._servers.values()))._pool
|
||||
with pool.get_socket(all_credentials) as sock_info:
|
||||
authset = sock_info.authset
|
||||
cached = set(all_credentials.values())
|
||||
self.assertEqual(len(cached), 1)
|
||||
self.assertFalse(authset - cached)
|
||||
self.assertFalse(cached - authset)
|
||||
|
||||
sock_credentials = next(iter(authset))
|
||||
sock_cache = sock_credentials.cache
|
||||
self.assertIsNotNone(sock_cache)
|
||||
self.assertEqual(sock_cache.data, data)
|
||||
|
||||
def test_scram_threaded(self):
|
||||
|
||||
coll = client_context.client.db.test
|
||||
|
||||
@ -44,7 +44,7 @@ def create_test(test_case):
|
||||
self.assertRaises(Exception, MongoClient, uri, connect=False)
|
||||
else:
|
||||
client = MongoClient(uri, connect=False)
|
||||
credentials = client._MongoClient__options._credentials
|
||||
credentials = client.options.pool_options._credentials
|
||||
if credential is None:
|
||||
self.assertIsNone(credentials)
|
||||
else:
|
||||
|
||||
@ -498,7 +498,7 @@ class TestClient(IntegrationTest):
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(
|
||||
readable_server_selector)
|
||||
with server._pool.get_socket({}) as sock_info:
|
||||
with server._pool.get_socket() as sock_info:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertTrue(sock_info in server._pool.sockets)
|
||||
@ -511,7 +511,7 @@ class TestClient(IntegrationTest):
|
||||
minPoolSize=1)
|
||||
server = client._get_topology().select_server(
|
||||
readable_server_selector)
|
||||
with server._pool.get_socket({}) as sock_info:
|
||||
with server._pool.get_socket() as sock_info:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket, two
|
||||
# sockets could be created and checked into the pool.
|
||||
@ -530,7 +530,7 @@ class TestClient(IntegrationTest):
|
||||
maxPoolSize=1)
|
||||
server = client._get_topology().select_server(
|
||||
readable_server_selector)
|
||||
with server._pool.get_socket({}) as sock_info:
|
||||
with server._pool.get_socket() as sock_info:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket,
|
||||
# maxPoolSize=1 should prevent two sockets from being created.
|
||||
@ -547,11 +547,11 @@ class TestClient(IntegrationTest):
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(
|
||||
readable_server_selector)
|
||||
with server._pool.get_socket({}) as sock_info_one:
|
||||
with server._pool.get_socket() as sock_info_one:
|
||||
pass
|
||||
# Assert that the pool does not close sockets prematurely.
|
||||
time.sleep(.300)
|
||||
with server._pool.get_socket({}) as sock_info_two:
|
||||
with server._pool.get_socket() as sock_info_two:
|
||||
pass
|
||||
self.assertIs(sock_info_one, sock_info_two)
|
||||
wait_until(
|
||||
@ -574,7 +574,7 @@ class TestClient(IntegrationTest):
|
||||
"pool initialized with 10 sockets")
|
||||
|
||||
# Assert that if a socket is closed, a new one takes its place
|
||||
with server._pool.get_socket({}) as sock_info:
|
||||
with server._pool.get_socket() as sock_info:
|
||||
sock_info.close_socket(None)
|
||||
wait_until(lambda: 10 == len(server._pool.sockets),
|
||||
"a closed socket gets replaced from the pool")
|
||||
@ -586,12 +586,12 @@ class TestClient(IntegrationTest):
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(
|
||||
readable_server_selector)
|
||||
with server._pool.get_socket({}) as sock_info:
|
||||
with server._pool.get_socket() as sock_info:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
time.sleep(1) # Sleep so that the socket becomes stale.
|
||||
|
||||
with server._pool.get_socket({}) as new_sock_info:
|
||||
with server._pool.get_socket() as new_sock_info:
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertFalse(sock_info in server._pool.sockets)
|
||||
@ -601,11 +601,11 @@ class TestClient(IntegrationTest):
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(
|
||||
readable_server_selector)
|
||||
with server._pool.get_socket({}) as sock_info:
|
||||
with server._pool.get_socket() as sock_info:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
time.sleep(1)
|
||||
with server._pool.get_socket({}) as new_sock_info:
|
||||
with server._pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
|
||||
@ -1106,7 +1106,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
def test_socketKeepAlive(self):
|
||||
pool = get_pool(self.client)
|
||||
with pool.get_socket({}) as sock_info:
|
||||
with pool.get_socket() as sock_info:
|
||||
keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET,
|
||||
socket.SO_KEEPALIVE)
|
||||
self.assertTrue(keepalive)
|
||||
@ -1325,8 +1325,8 @@ class TestClient(IntegrationTest):
|
||||
socket_info = one(pool.sockets)
|
||||
socket_info.sock.close()
|
||||
|
||||
# SocketInfo.check_auth logs in with the new credential, but gets a
|
||||
# socket.error. Should be reraised as AutoReconnect.
|
||||
# SocketInfo.authenticate logs, but gets a socket.error. Should be
|
||||
# reraised as AutoReconnect.
|
||||
self.assertRaises(AutoReconnect, c.test.collection.find_one)
|
||||
|
||||
# No semaphore leak, the pool is allowed to make a new socket.
|
||||
@ -1521,8 +1521,7 @@ class TestClient(IntegrationTest):
|
||||
try:
|
||||
while True:
|
||||
for _ in range(10):
|
||||
client._topology.update_pool(
|
||||
client._MongoClient__all_credentials)
|
||||
client._topology.update_pool()
|
||||
if generation != pool.gen.get_overall():
|
||||
break
|
||||
finally:
|
||||
|
||||
@ -120,7 +120,7 @@ class TestCMAP(IntegrationTest):
|
||||
def check_out(self, op):
|
||||
"""Run the 'checkOut' operation."""
|
||||
label = op['label']
|
||||
with self.pool.get_socket({}) as sock_info:
|
||||
with self.pool.get_socket() as sock_info:
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
sock_info.pin_cursor()
|
||||
if label:
|
||||
@ -452,7 +452,7 @@ class TestCMAP(IntegrationTest):
|
||||
self.assertEqual(1, listener.event_count(PoolClearedEvent))
|
||||
self.assertEqual(PoolState.READY, pool.state)
|
||||
# Checking out a connection should succeed
|
||||
with pool.get_socket({}):
|
||||
with pool.get_socket():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -118,7 +118,7 @@ class SocketGetter(MongoThread):
|
||||
self.state = 'get_socket'
|
||||
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
with self.pool.get_socket({}) as sock:
|
||||
with self.pool.get_socket() as sock:
|
||||
sock.pin_cursor()
|
||||
self.sock = sock
|
||||
|
||||
@ -196,10 +196,10 @@ class TestPooling(_TestPoolingBase):
|
||||
# Test Pool's _check_closed() method doesn't close a healthy socket.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
with cx_pool.get_socket({}) as sock_info:
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
pass
|
||||
|
||||
with cx_pool.get_socket({}) as new_sock_info:
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
@ -208,11 +208,11 @@ class TestPooling(_TestPoolingBase):
|
||||
# get_socket() returns socket after a non-network error.
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with cx_pool.get_socket({}) as sock_info:
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
1 / 0
|
||||
|
||||
# Socket was returned, not closed.
|
||||
with cx_pool.get_socket({}) as new_sock_info:
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
@ -221,7 +221,7 @@ class TestPooling(_TestPoolingBase):
|
||||
# Test that Pool removes explicitly closed socket.
|
||||
cx_pool = self.create_pool()
|
||||
|
||||
with cx_pool.get_socket({}) as sock_info:
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Use SocketInfo's API to close the socket.
|
||||
sock_info.close_socket(None)
|
||||
|
||||
@ -233,20 +233,20 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
|
||||
with cx_pool.get_socket({}) as sock_info:
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
self.assertTrue(sock_info.socket_closed())
|
||||
|
||||
with cx_pool.get_socket({}) as new_sock_info:
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
# Semaphore was released.
|
||||
with cx_pool.get_socket({}):
|
||||
with cx_pool.get_socket():
|
||||
pass
|
||||
|
||||
def test_socket_closed(self):
|
||||
@ -290,7 +290,7 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
def test_return_socket_after_reset(self):
|
||||
pool = self.create_pool()
|
||||
with pool.get_socket({}) as sock:
|
||||
with pool.get_socket() as sock:
|
||||
self.assertEqual(pool.active_sockets, 1)
|
||||
self.assertEqual(pool.operation_count, 1)
|
||||
pool.reset()
|
||||
@ -309,7 +309,7 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
self.addCleanup(cx_pool.close)
|
||||
|
||||
with cx_pool.get_socket({}) as sock_info:
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
@ -317,12 +317,12 @@ class TestPooling(_TestPoolingBase):
|
||||
# Swap pool's address with a bad one.
|
||||
address, cx_pool.address = cx_pool.address, ('foo.com', 1234)
|
||||
with self.assertRaises(AutoReconnect):
|
||||
with cx_pool.get_socket({}):
|
||||
with cx_pool.get_socket():
|
||||
pass
|
||||
|
||||
# Back to normal, semaphore was correctly released.
|
||||
cx_pool.address = address
|
||||
with cx_pool.get_socket({}):
|
||||
with cx_pool.get_socket():
|
||||
pass
|
||||
|
||||
def test_wait_queue_timeout(self):
|
||||
@ -331,10 +331,10 @@ class TestPooling(_TestPoolingBase):
|
||||
max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
|
||||
self.addCleanup(pool.close)
|
||||
|
||||
with pool.get_socket({}) as sock_info:
|
||||
with pool.get_socket() as sock_info:
|
||||
start = time.time()
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
with pool.get_socket({}):
|
||||
with pool.get_socket():
|
||||
pass
|
||||
|
||||
duration = time.time() - start
|
||||
@ -349,7 +349,7 @@ class TestPooling(_TestPoolingBase):
|
||||
self.addCleanup(pool.close)
|
||||
|
||||
# Reach max_size.
|
||||
with pool.get_socket({}) as s1:
|
||||
with pool.get_socket() as s1:
|
||||
t = SocketGetter(self.c, pool)
|
||||
t.start()
|
||||
while t.state != 'get_socket':
|
||||
@ -370,7 +370,7 @@ class TestPooling(_TestPoolingBase):
|
||||
socks = []
|
||||
for _ in range(2):
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
with pool.get_socket({}) as sock:
|
||||
with pool.get_socket() as sock:
|
||||
sock.pin_cursor()
|
||||
socks.append(sock)
|
||||
|
||||
@ -515,7 +515,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
# socket from pool" instead of AutoReconnect.
|
||||
for i in range(2):
|
||||
with self.assertRaises(AutoReconnect) as context:
|
||||
with test_pool.get_socket({}):
|
||||
with test_pool.get_socket():
|
||||
pass
|
||||
|
||||
# Testing for AutoReconnect instead of ConnectionFailure, above,
|
||||
|
||||
@ -269,7 +269,7 @@ class MockPool(object):
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def get_socket(self, all_credentials, handler=None):
|
||||
def get_socket(self, handler=None):
|
||||
return MockSocketInfo()
|
||||
|
||||
def return_socket(self, *args, **kwargs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user