PYTHON-2585 Remove legacy multi-auth code (#816)

This commit is contained in:
Shane Harvey 2021-12-09 18:00:41 -08:00 committed by GitHub
parent 7bd9bd7b47
commit c94a3ad1df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 106 additions and 167 deletions

View File

@ -195,7 +195,7 @@ def _authenticate_scram(credentials, sock_info, mechanism):
# Make local
_hmac = hmac.HMAC
ctx = sock_info.auth_ctx.get(credentials)
ctx = sock_info.auth_ctx
if ctx and ctx.speculate_succeeded():
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
@ -424,7 +424,7 @@ def _authenticate_plain(credentials, sock_info):
def _authenticate_x509(credentials, sock_info):
"""Authenticate using MONGODB-X509.
"""
ctx = sock_info.auth_ctx.get(credentials)
ctx = sock_info.auth_ctx
if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step.
return
@ -454,8 +454,8 @@ def _authenticate_mongo_cr(credentials, sock_info):
def _authenticate_default(credentials, sock_info):
if sock_info.max_wire_version >= 7:
if credentials in sock_info.negotiated_mechanisms:
mechs = sock_info.negotiated_mechanisms[credentials]
if sock_info.negotiated_mechs:
mechs = sock_info.negotiated_mechs
else:
source = credentials.source
cmd = sock_info.hello_cmd()

View File

@ -117,8 +117,9 @@ def _parse_ssl_options(options):
return None, allow_invalid_hostnames
def _parse_pool_options(options):
def _parse_pool_options(username, password, database, options):
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get('maxpoolsize', common.MAX_POOL_SIZE)
min_pool_size = options.get('minpoolsize', common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get(
@ -151,7 +152,8 @@ def _parse_pool_options(options):
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced)
load_balanced=load_balanced,
credentials=credentials)
class ClientOptions(object):
@ -164,10 +166,7 @@ class ClientOptions(object):
def __init__(self, username, password, database, options):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__credentials = _parse_credentials(
username, password, database, options)
self.__direct_connection = options.get('directconnection')
self.__local_threshold_ms = options.get(
'localthresholdms', common.LOCAL_THRESHOLD_MS)
@ -175,7 +174,8 @@ class ClientOptions(object):
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
'serverselectiontimeoutms', common.SERVER_SELECTION_TIMEOUT)
self.__pool_options = _parse_pool_options(options)
self.__pool_options = _parse_pool_options(
username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get('replicaset')
self.__write_concern = _parse_write_concern(options)
@ -205,11 +205,6 @@ class ClientOptions(object):
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def _credentials(self):
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
return self.__credentials
@property
def direct_connection(self):
"""Whether to connect to the deployment in 'Single' topology."""

View File

@ -729,11 +729,6 @@ class MongoClient(common.BaseObject):
options.write_concern,
options.read_concern)
self.__all_credentials = {}
creds = options._credentials
if creds:
self.__all_credentials[creds.source] = creds
self._topology_settings = TopologySettings(
seeds=seeds,
replica_set_name=options.replica_set_name,
@ -1090,8 +1085,7 @@ class MongoClient(common.BaseObject):
if in_txn and session._pinned_connection:
yield session._pinned_connection
return
with server.get_socket(
self.__all_credentials, handler=err_handler) as sock_info:
with server.get_socket(handler=err_handler) as sock_info:
# Pin this session to the selected server or connection.
if (in_txn and server.description.server_type in (
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)):
@ -1535,7 +1529,7 @@ class MongoClient(common.BaseObject):
maintain connection pool parameters."""
try:
self._process_kill_cursors()
self._topology.update_pool(self.__all_credentials)
self._topology.update_pool()
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -246,7 +246,7 @@ class Monitor(MonitorBase):
if self._cancel_context and self._cancel_context.cancelled:
self._reset_connection()
with self._pool.get_socket({}) as sock_info:
with self._pool.get_socket() as sock_info:
self._cancel_context = sock_info.cancel_context
response, round_trip_time = self._check_with_socket(sock_info)
if not response.awaitable:
@ -275,11 +275,10 @@ class Monitor(MonitorBase):
response = conn._hello(
cluster_time,
self._server_description.topology_version,
self._settings.heartbeat_frequency,
None)
self._settings.heartbeat_frequency)
else:
# New connection handshake or polling hello (MongoDB <4.4).
response = conn._hello(cluster_time, None, None, None)
response = conn._hello(cluster_time, None, None)
return response, time.monotonic() - start
@ -388,7 +387,7 @@ class _RttMonitor(MonitorBase):
def _ping(self):
"""Run a "hello" command and return the RTT."""
with self._pool.get_socket({}) as sock_info:
with self._pool.get_socket() as sock_info:
if self._executor._stopped:
raise Exception('_RttMonitor closed')
start = time.monotonic()

View File

@ -275,7 +275,8 @@ class PoolOptions(object):
'__ssl_context', '__tls_allow_invalid_hostnames',
'__event_listeners', '__appname', '__driver', '__metadata',
'__compression_settings', '__max_connecting',
'__pause_enabled', '__server_api', '__load_balanced')
'__pause_enabled', '__server_api', '__load_balanced',
'__credentials')
def __init__(self, max_pool_size=MAX_POOL_SIZE,
min_pool_size=MIN_POOL_SIZE,
@ -285,7 +286,8 @@ class PoolOptions(object):
tls_allow_invalid_hostnames=False,
event_listeners=None, appname=None, driver=None,
compression_settings=None, max_connecting=MAX_CONNECTING,
pause_enabled=True, server_api=None, load_balanced=None):
pause_enabled=True, server_api=None, load_balanced=None,
credentials=None):
self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size
self.__max_idle_time_seconds = max_idle_time_seconds
@ -302,6 +304,7 @@ class PoolOptions(object):
self.__pause_enabled = pause_enabled
self.__server_api = server_api
self.__load_balanced = load_balanced
self.__credentials = credentials
self.__metadata = copy.deepcopy(_METADATA)
if appname:
self.__metadata['application'] = {'name': appname}
@ -325,6 +328,11 @@ class PoolOptions(object):
self.__metadata['platform'] = "%s|%s" % (
_METADATA['platform'], driver.platform)
@property
def _credentials(self):
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
return self.__credentials
@property
def non_default_options(self):
"""The non-default options this pool was created with.
@ -457,25 +465,6 @@ class PoolOptions(object):
return self.__load_balanced
def _negotiate_creds(all_credentials):
"""Return one credential that needs mechanism negotiation, if any.
"""
if all_credentials:
for creds in all_credentials.values():
if creds.mechanism == 'DEFAULT' and creds.username:
return creds
return None
def _speculative_context(all_credentials):
"""Return the _AuthContext to use for speculative auth, if any.
"""
if all_credentials and len(all_credentials) == 1:
creds = next(iter(all_credentials.values()))
return auth._AuthContext.from_credentials(creds)
return None
class _CancellationContext(object):
def __init__(self):
self._cancelled = False
@ -504,7 +493,7 @@ class SocketInfo(object):
self.sock = sock
self.address = address
self.id = id
self.authset = set()
self.authed = set()
self.closed = False
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
@ -523,9 +512,8 @@ class SocketInfo(object):
self.compression_context = None
self.socket_checker = SocketChecker()
# Support for mechanism negotiation on the initial handshake.
# Maps credential to saslSupportedMechs.
self.negotiated_mechanisms = {}
self.auth_ctx = {}
self.negotiated_mechs = None
self.auth_ctx = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
@ -567,11 +555,10 @@ class SocketInfo(object):
else:
return SON([(HelloCompat.LEGACY_CMD, 1), ('helloOk', True)])
def hello(self, all_credentials=None):
return self._hello(None, None, None, all_credentials)
def hello(self):
return self._hello(None, None, None)
def _hello(self, cluster_time, topology_version,
heartbeat_frequency, all_credentials):
def _hello(self, cluster_time, topology_version, heartbeat_frequency):
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
@ -594,14 +581,15 @@ class SocketInfo(object):
if not performing_handshake and cluster_time is not None:
cmd['$clusterTime'] = cluster_time
# XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
# unchangeable value per MongoClient.
creds = _negotiate_creds(all_credentials)
creds = self.opts._credentials
if creds:
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
auth_ctx = _speculative_context(all_credentials)
if auth_ctx:
cmd['speculativeAuthenticate'] = auth_ctx.speculate_command()
if creds.mechanism == 'DEFAULT' and creds.username:
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
auth_ctx = auth._AuthContext.from_credentials(creds)
if auth_ctx:
cmd['speculativeAuthenticate'] = auth_ctx.speculate_command()
else:
auth_ctx = None
doc = self.command('admin', cmd, publish_events=False,
exhaust_allowed=awaitable)
@ -628,11 +616,11 @@ class SocketInfo(object):
self.op_msg_enabled = True
if creds:
self.negotiated_mechanisms[creds] = hello.sasl_supported_mechs
self.negotiated_mechs = hello.sasl_supported_mechs
if auth_ctx:
auth_ctx.parse_response(hello)
if auth_ctx.speculate_succeeded():
self.auth_ctx[auth_ctx.credentials] = auth_ctx
self.auth_ctx = auth_ctx
if self.opts.load_balanced:
if not hello.service_id:
raise ConfigurationError(
@ -799,41 +787,21 @@ class SocketInfo(object):
helpers._check_command_response(result, self.max_wire_version)
return result
def check_auth(self, all_credentials):
"""Update this socket's authentication.
def authenticate(self):
"""Authenticate to the server if needed.
Log in or out to bring this socket's credentials up to date with
those provided. Can raise ConnectionFailure or OperationFailure.
:Parameters:
- `all_credentials`: dict, maps auth source to MongoCredential.
Can raise ConnectionFailure or OperationFailure.
"""
if all_credentials:
for credentials in all_credentials.values():
if credentials not in self.authset:
self.authenticate(credentials)
# CMAP spec says to publish the ready event only after authenticating
# the connection.
if not self.ready:
creds = self.opts._credentials
if creds:
auth.authenticate(creds, self)
self.ready = True
if self.enabled_for_cmap:
self.listeners.publish_connection_ready(self.address, self.id)
def authenticate(self, credentials):
"""Log in to the server and store these credentials in `authset`.
Can raise ConnectionFailure or OperationFailure.
:Parameters:
- `credentials`: A MongoCredential.
"""
auth.authenticate(credentials, self)
self.authset.add(credentials)
# negotiated_mechanisms are no longer needed.
self.negotiated_mechanisms.pop(credentials, None)
self.auth_ctx.pop(credentials, None)
def validate_session(self, client, session):
"""Validate this session before use with client.
@ -1245,7 +1213,7 @@ class Pool:
def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id)
def remove_stale_sockets(self, reference_generation, all_credentials):
def remove_stale_sockets(self, reference_generation):
"""Removes stale sockets then adds new ones if pool is too small and
has not been reset. The `reference_generation` argument specifies the
`generation` at the point in time this operation was requested on the
@ -1281,7 +1249,7 @@ class Pool:
return
self._pending += 1
incremented = True
sock_info = self.connect(all_credentials)
sock_info = self.connect()
with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
@ -1300,7 +1268,7 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def connect(self, all_credentials=None):
def connect(self):
"""Connect to Mongo and return a new SocketInfo.
Can raise ConnectionFailure.
@ -1331,10 +1299,10 @@ class Pool:
sock_info = SocketInfo(sock, self, self.address, conn_id)
try:
if self.handshake:
sock_info.hello(all_credentials)
sock_info.hello()
self.is_writable = sock_info.is_writable
sock_info.check_auth(all_credentials)
sock_info.authenticate()
except BaseException:
sock_info.close_socket(ConnectionClosedReason.ERROR)
raise
@ -1342,7 +1310,7 @@ class Pool:
return sock_info
@contextlib.contextmanager
def get_socket(self, all_credentials, handler=None):
def get_socket(self, handler=None):
"""Get a socket from the pool. Use with a "with" statement.
Returns a :class:`SocketInfo` object wrapping a connected
@ -1350,25 +1318,20 @@ class Pool:
This method should always be used in a with-statement::
with pool.get_socket(credentials) as socket_info:
with pool.get_socket() as socket_info:
socket_info.send_message(msg)
data = socket_info.receive_message(op_code, request_id)
The socket is logged in or out as needed to match ``all_credentials``
using the correct authentication mechanism for the server's wire
protocol version.
Can raise ConnectionFailure or OperationFailure.
:Parameters:
- `all_credentials`: dict, maps auth source to MongoCredential.
- `handler` (optional): A _MongoClientErrorHandler.
"""
listeners = self.opts._event_listeners
if self.enabled_for_cmap:
listeners.publish_connection_check_out_started(self.address)
sock_info = self._get_socket(all_credentials)
sock_info = self._get_socket()
if self.enabled_for_cmap:
listeners.publish_connection_checked_out(
self.address, sock_info.id)
@ -1407,7 +1370,7 @@ class Pool:
_raise_connection_failure(
self.address, AutoReconnect('connection pool paused'))
def _get_socket(self, all_credentials):
def _get_socket(self):
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of
@ -1480,12 +1443,11 @@ class Pool:
continue
else: # We need to create a new connection
try:
sock_info = self.connect(all_credentials)
sock_info = self.connect()
finally:
with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
sock_info.check_auth(all_credentials)
except BaseException:
if sock_info:
# We checked out a socket but authentication failed.

View File

@ -77,9 +77,9 @@ class Server(object):
Can raise ConnectionFailure, OperationFailure, etc.
:Parameters:
- `sock_info` - A SocketInfo instance.
- `operation`: A _Query or _GetMore object.
- `set_secondary_okay`: Pass to operation.get_message.
- `all_credentials`: dict, maps auth source to MongoCredential.
- `listeners`: Instance of _EventListeners or None.
- `unpack_res`: A callable that decodes the wire protocol response.
"""
@ -200,8 +200,8 @@ class Server(object):
return response
def get_socket(self, all_credentials, handler=None):
return self.pool.get_socket(all_credentials, handler)
def get_socket(self, handler=None):
return self.pool.get_socket(handler)
@property
def description(self):

View File

@ -444,7 +444,7 @@ class Topology(object):
return self._description.known_servers
return self._description.readable_servers
def update_pool(self, all_credentials):
def update_pool(self):
# Remove any stale sockets and add new sockets if pool is too small.
servers = []
with self._lock:
@ -456,7 +456,7 @@ class Topology(object):
for server, generation in servers:
try:
server.pool.remove_stale_sockets(generation, all_credentials)
server.pool.remove_stale_sockets(generation)
except PyMongoError as exc:
ctx = _ErrorContext(exc, 0, generation, False, None)
self.handle_error(server.description.address, ctx)

View File

@ -40,7 +40,7 @@ class MockPool(Pool):
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
@contextlib.contextmanager
def get_socket(self, all_credentials, handler=None):
def get_socket(self, handler=None):
client = self.client
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
if host_and_port in client.mock_down_hosts:
@ -51,7 +51,7 @@ class MockPool(Pool):
+ client.mock_members
+ client.mock_mongoses), "bad host: %s" % host_and_port
with Pool.get_socket(self, all_credentials, handler) as sock_info:
with Pool.get_socket(self, handler) as sock_info:
sock_info.mock_host = self.mock_host
sock_info.mock_port = self.mock_port
yield sock_info

View File

@ -30,6 +30,7 @@ from pymongo.read_preferences import ReadPreference
from pymongo.saslprep import HAVE_STRINGPREP
from test import client_context, IntegrationTest, SkipTest, unittest, Version
from test.utils import (delay,
get_pool,
ignore_deprecations,
single_client,
rs_or_single_client,
@ -521,10 +522,12 @@ class TestSCRAM(IntegrationTest):
def test_cache(self):
client = single_client()
credentials = client.options.pool_options._credentials
cache = credentials.cache
self.assertIsNotNone(cache)
self.assertIsNone(cache.data)
# Force authentication.
client.admin.command('ping')
all_credentials = client._MongoClient__all_credentials
credentials = all_credentials.get('admin')
cache = credentials.cache
self.assertIsNotNone(cache)
data = cache.data
@ -536,19 +539,6 @@ class TestSCRAM(IntegrationTest):
self.assertIsInstance(salt, bytes)
self.assertIsInstance(iterations, int)
pool = next(iter(client._topology._servers.values()))._pool
with pool.get_socket(all_credentials) as sock_info:
authset = sock_info.authset
cached = set(all_credentials.values())
self.assertEqual(len(cached), 1)
self.assertFalse(authset - cached)
self.assertFalse(cached - authset)
sock_credentials = next(iter(authset))
sock_cache = sock_credentials.cache
self.assertIsNotNone(sock_cache)
self.assertEqual(sock_cache.data, data)
def test_scram_threaded(self):
coll = client_context.client.db.test

View File

@ -44,7 +44,7 @@ def create_test(test_case):
self.assertRaises(Exception, MongoClient, uri, connect=False)
else:
client = MongoClient(uri, connect=False)
credentials = client._MongoClient__options._credentials
credentials = client.options.pool_options._credentials
if credential is None:
self.assertIsNone(credentials)
else:

View File

@ -498,7 +498,7 @@ class TestClient(IntegrationTest):
client = rs_or_single_client()
server = client._get_topology().select_server(
readable_server_selector)
with server._pool.get_socket({}) as sock_info:
with server._pool.get_socket() as sock_info:
pass
self.assertEqual(1, len(server._pool.sockets))
self.assertTrue(sock_info in server._pool.sockets)
@ -511,7 +511,7 @@ class TestClient(IntegrationTest):
minPoolSize=1)
server = client._get_topology().select_server(
readable_server_selector)
with server._pool.get_socket({}) as sock_info:
with server._pool.get_socket() as sock_info:
pass
# When the reaper runs at the same time as the get_socket, two
# sockets could be created and checked into the pool.
@ -530,7 +530,7 @@ class TestClient(IntegrationTest):
maxPoolSize=1)
server = client._get_topology().select_server(
readable_server_selector)
with server._pool.get_socket({}) as sock_info:
with server._pool.get_socket() as sock_info:
pass
# When the reaper runs at the same time as the get_socket,
# maxPoolSize=1 should prevent two sockets from being created.
@ -547,11 +547,11 @@ class TestClient(IntegrationTest):
client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(
readable_server_selector)
with server._pool.get_socket({}) as sock_info_one:
with server._pool.get_socket() as sock_info_one:
pass
# Assert that the pool does not close sockets prematurely.
time.sleep(.300)
with server._pool.get_socket({}) as sock_info_two:
with server._pool.get_socket() as sock_info_two:
pass
self.assertIs(sock_info_one, sock_info_two)
wait_until(
@ -574,7 +574,7 @@ class TestClient(IntegrationTest):
"pool initialized with 10 sockets")
# Assert that if a socket is closed, a new one takes its place
with server._pool.get_socket({}) as sock_info:
with server._pool.get_socket() as sock_info:
sock_info.close_socket(None)
wait_until(lambda: 10 == len(server._pool.sockets),
"a closed socket gets replaced from the pool")
@ -586,12 +586,12 @@ class TestClient(IntegrationTest):
client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(
readable_server_selector)
with server._pool.get_socket({}) as sock_info:
with server._pool.get_socket() as sock_info:
pass
self.assertEqual(1, len(server._pool.sockets))
time.sleep(1) # Sleep so that the socket becomes stale.
with server._pool.get_socket({}) as new_sock_info:
with server._pool.get_socket() as new_sock_info:
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(1, len(server._pool.sockets))
self.assertFalse(sock_info in server._pool.sockets)
@ -601,11 +601,11 @@ class TestClient(IntegrationTest):
client = rs_or_single_client()
server = client._get_topology().select_server(
readable_server_selector)
with server._pool.get_socket({}) as sock_info:
with server._pool.get_socket() as sock_info:
pass
self.assertEqual(1, len(server._pool.sockets))
time.sleep(1)
with server._pool.get_socket({}) as new_sock_info:
with server._pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info)
self.assertEqual(1, len(server._pool.sockets))
@ -1106,7 +1106,7 @@ class TestClient(IntegrationTest):
def test_socketKeepAlive(self):
pool = get_pool(self.client)
with pool.get_socket({}) as sock_info:
with pool.get_socket() as sock_info:
keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET,
socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@ -1325,8 +1325,8 @@ class TestClient(IntegrationTest):
socket_info = one(pool.sockets)
socket_info.sock.close()
# SocketInfo.check_auth logs in with the new credential, but gets a
# socket.error. Should be reraised as AutoReconnect.
# SocketInfo.authenticate logs, but gets a socket.error. Should be
# reraised as AutoReconnect.
self.assertRaises(AutoReconnect, c.test.collection.find_one)
# No semaphore leak, the pool is allowed to make a new socket.
@ -1521,8 +1521,7 @@ class TestClient(IntegrationTest):
try:
while True:
for _ in range(10):
client._topology.update_pool(
client._MongoClient__all_credentials)
client._topology.update_pool()
if generation != pool.gen.get_overall():
break
finally:

View File

@ -120,7 +120,7 @@ class TestCMAP(IntegrationTest):
def check_out(self, op):
"""Run the 'checkOut' operation."""
label = op['label']
with self.pool.get_socket({}) as sock_info:
with self.pool.get_socket() as sock_info:
# Call 'pin_cursor' so we can hold the socket.
sock_info.pin_cursor()
if label:
@ -452,7 +452,7 @@ class TestCMAP(IntegrationTest):
self.assertEqual(1, listener.event_count(PoolClearedEvent))
self.assertEqual(PoolState.READY, pool.state)
# Checking out a connection should succeed
with pool.get_socket({}):
with pool.get_socket():
pass

View File

@ -118,7 +118,7 @@ class SocketGetter(MongoThread):
self.state = 'get_socket'
# Call 'pin_cursor' so we can hold the socket.
with self.pool.get_socket({}) as sock:
with self.pool.get_socket() as sock:
sock.pin_cursor()
self.sock = sock
@ -196,10 +196,10 @@ class TestPooling(_TestPoolingBase):
# Test Pool's _check_closed() method doesn't close a healthy socket.
cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket({}) as sock_info:
with cx_pool.get_socket() as sock_info:
pass
with cx_pool.get_socket({}) as new_sock_info:
with cx_pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
@ -208,11 +208,11 @@ class TestPooling(_TestPoolingBase):
# get_socket() returns socket after a non-network error.
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
with self.assertRaises(ZeroDivisionError):
with cx_pool.get_socket({}) as sock_info:
with cx_pool.get_socket() as sock_info:
1 / 0
# Socket was returned, not closed.
with cx_pool.get_socket({}) as new_sock_info:
with cx_pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
@ -221,7 +221,7 @@ class TestPooling(_TestPoolingBase):
# Test that Pool removes explicitly closed socket.
cx_pool = self.create_pool()
with cx_pool.get_socket({}) as sock_info:
with cx_pool.get_socket() as sock_info:
# Use SocketInfo's API to close the socket.
sock_info.close_socket(None)
@ -233,20 +233,20 @@ class TestPooling(_TestPoolingBase):
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket({}) as sock_info:
with cx_pool.get_socket() as sock_info:
# Simulate a closed socket without telling the SocketInfo it's
# closed.
sock_info.sock.close()
self.assertTrue(sock_info.socket_closed())
with cx_pool.get_socket({}) as new_sock_info:
with cx_pool.get_socket() as new_sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
# Semaphore was released.
with cx_pool.get_socket({}):
with cx_pool.get_socket():
pass
def test_socket_closed(self):
@ -290,7 +290,7 @@ class TestPooling(_TestPoolingBase):
def test_return_socket_after_reset(self):
pool = self.create_pool()
with pool.get_socket({}) as sock:
with pool.get_socket() as sock:
self.assertEqual(pool.active_sockets, 1)
self.assertEqual(pool.operation_count, 1)
pool.reset()
@ -309,7 +309,7 @@ class TestPooling(_TestPoolingBase):
cx_pool._check_interval_seconds = 0 # Always check.
self.addCleanup(cx_pool.close)
with cx_pool.get_socket({}) as sock_info:
with cx_pool.get_socket() as sock_info:
# Simulate a closed socket without telling the SocketInfo it's
# closed.
sock_info.sock.close()
@ -317,12 +317,12 @@ class TestPooling(_TestPoolingBase):
# Swap pool's address with a bad one.
address, cx_pool.address = cx_pool.address, ('foo.com', 1234)
with self.assertRaises(AutoReconnect):
with cx_pool.get_socket({}):
with cx_pool.get_socket():
pass
# Back to normal, semaphore was correctly released.
cx_pool.address = address
with cx_pool.get_socket({}):
with cx_pool.get_socket():
pass
def test_wait_queue_timeout(self):
@ -331,10 +331,10 @@ class TestPooling(_TestPoolingBase):
max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
self.addCleanup(pool.close)
with pool.get_socket({}) as sock_info:
with pool.get_socket() as sock_info:
start = time.time()
with self.assertRaises(ConnectionFailure):
with pool.get_socket({}):
with pool.get_socket():
pass
duration = time.time() - start
@ -349,7 +349,7 @@ class TestPooling(_TestPoolingBase):
self.addCleanup(pool.close)
# Reach max_size.
with pool.get_socket({}) as s1:
with pool.get_socket() as s1:
t = SocketGetter(self.c, pool)
t.start()
while t.state != 'get_socket':
@ -370,7 +370,7 @@ class TestPooling(_TestPoolingBase):
socks = []
for _ in range(2):
# Call 'pin_cursor' so we can hold the socket.
with pool.get_socket({}) as sock:
with pool.get_socket() as sock:
sock.pin_cursor()
socks.append(sock)
@ -515,7 +515,7 @@ class TestPoolMaxSize(_TestPoolingBase):
# socket from pool" instead of AutoReconnect.
for i in range(2):
with self.assertRaises(AutoReconnect) as context:
with test_pool.get_socket({}):
with test_pool.get_socket():
pass
# Testing for AutoReconnect instead of ConnectionFailure, above,

View File

@ -269,7 +269,7 @@ class MockPool(object):
def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id)
def get_socket(self, all_credentials, handler=None):
def get_socket(self, handler=None):
return MockSocketInfo()
def return_socket(self, *args, **kwargs):