diff --git a/pymongo/auth.py b/pymongo/auth.py index 17f3a32fe..a2e206357 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -195,7 +195,7 @@ def _authenticate_scram(credentials, sock_info, mechanism): # Make local _hmac = hmac.HMAC - ctx = sock_info.auth_ctx.get(credentials) + ctx = sock_info.auth_ctx if ctx and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate @@ -424,7 +424,7 @@ def _authenticate_plain(credentials, sock_info): def _authenticate_x509(credentials, sock_info): """Authenticate using MONGODB-X509. """ - ctx = sock_info.auth_ctx.get(credentials) + ctx = sock_info.auth_ctx if ctx and ctx.speculate_succeeded(): # MONGODB-X509 is done after the speculative auth step. return @@ -454,8 +454,8 @@ def _authenticate_mongo_cr(credentials, sock_info): def _authenticate_default(credentials, sock_info): if sock_info.max_wire_version >= 7: - if credentials in sock_info.negotiated_mechanisms: - mechs = sock_info.negotiated_mechanisms[credentials] + if sock_info.negotiated_mechs: + mechs = sock_info.negotiated_mechs else: source = credentials.source cmd = sock_info.hello_cmd() diff --git a/pymongo/client_options.py b/pymongo/client_options.py index f7dbf255b..c2f5ae01c 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -117,8 +117,9 @@ def _parse_ssl_options(options): return None, allow_invalid_hostnames -def _parse_pool_options(options): +def _parse_pool_options(username, password, database, options): """Parse connection pool options.""" + credentials = _parse_credentials(username, password, database, options) max_pool_size = options.get('maxpoolsize', common.MAX_POOL_SIZE) min_pool_size = options.get('minpoolsize', common.MIN_POOL_SIZE) max_idle_time_seconds = options.get( @@ -151,7 +152,8 @@ def _parse_pool_options(options): compression_settings, max_connecting=max_connecting, server_api=server_api, - load_balanced=load_balanced) + load_balanced=load_balanced, + credentials=credentials) class ClientOptions(object): @@ -164,10 +166,7 @@ class ClientOptions(object): def __init__(self, username, password, database, options): self.__options = options - self.__codec_options = _parse_codec_options(options) - self.__credentials = _parse_credentials( - username, password, database, options) self.__direct_connection = options.get('directconnection') self.__local_threshold_ms = options.get( 'localthresholdms', common.LOCAL_THRESHOLD_MS) @@ -175,7 +174,8 @@ class ClientOptions(object): # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. self.__server_selection_timeout = options.get( 'serverselectiontimeoutms', common.SERVER_SELECTION_TIMEOUT) - self.__pool_options = _parse_pool_options(options) + self.__pool_options = _parse_pool_options( + username, password, database, options) self.__read_preference = _parse_read_preference(options) self.__replica_set_name = options.get('replicaset') self.__write_concern = _parse_write_concern(options) @@ -205,11 +205,6 @@ class ClientOptions(object): """A :class:`~bson.codec_options.CodecOptions` instance.""" return self.__codec_options - @property - def _credentials(self): - """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" - return self.__credentials - @property def direct_connection(self): """Whether to connect to the deployment in 'Single' topology.""" diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 9c98e5d21..87c87c024 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -729,11 +729,6 @@ class MongoClient(common.BaseObject): options.write_concern, options.read_concern) - self.__all_credentials = {} - creds = options._credentials - if creds: - self.__all_credentials[creds.source] = creds - self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=options.replica_set_name, @@ -1090,8 +1085,7 @@ class MongoClient(common.BaseObject): if in_txn and session._pinned_connection: yield session._pinned_connection return - with server.get_socket( - self.__all_credentials, handler=err_handler) as sock_info: + with server.get_socket(handler=err_handler) as sock_info: # Pin this session to the selected server or connection. if (in_txn and server.description.server_type in ( SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)): @@ -1535,7 +1529,7 @@ class MongoClient(common.BaseObject): maintain connection pool parameters.""" try: self._process_kill_cursors() - self._topology.update_pool(self.__all_credentials) + self._topology.update_pool() except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/monitor.py b/pymongo/monitor.py index a383e272c..039ec5194 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -246,7 +246,7 @@ class Monitor(MonitorBase): if self._cancel_context and self._cancel_context.cancelled: self._reset_connection() - with self._pool.get_socket({}) as sock_info: + with self._pool.get_socket() as sock_info: self._cancel_context = sock_info.cancel_context response, round_trip_time = self._check_with_socket(sock_info) if not response.awaitable: @@ -275,11 +275,10 @@ class Monitor(MonitorBase): response = conn._hello( cluster_time, self._server_description.topology_version, - self._settings.heartbeat_frequency, - None) + self._settings.heartbeat_frequency) else: # New connection handshake or polling hello (MongoDB <4.4). - response = conn._hello(cluster_time, None, None, None) + response = conn._hello(cluster_time, None, None) return response, time.monotonic() - start @@ -388,7 +387,7 @@ class _RttMonitor(MonitorBase): def _ping(self): """Run a "hello" command and return the RTT.""" - with self._pool.get_socket({}) as sock_info: + with self._pool.get_socket() as sock_info: if self._executor._stopped: raise Exception('_RttMonitor closed') start = time.monotonic() diff --git a/pymongo/pool.py b/pymongo/pool.py index 84661c487..6fe9d024d 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -275,7 +275,8 @@ class PoolOptions(object): '__ssl_context', '__tls_allow_invalid_hostnames', '__event_listeners', '__appname', '__driver', '__metadata', '__compression_settings', '__max_connecting', - '__pause_enabled', '__server_api', '__load_balanced') + '__pause_enabled', '__server_api', '__load_balanced', + '__credentials') def __init__(self, max_pool_size=MAX_POOL_SIZE, min_pool_size=MIN_POOL_SIZE, @@ -285,7 +286,8 @@ class PoolOptions(object): tls_allow_invalid_hostnames=False, event_listeners=None, appname=None, driver=None, compression_settings=None, max_connecting=MAX_CONNECTING, - pause_enabled=True, server_api=None, load_balanced=None): + pause_enabled=True, server_api=None, load_balanced=None, + credentials=None): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size self.__max_idle_time_seconds = max_idle_time_seconds @@ -302,6 +304,7 @@ class PoolOptions(object): self.__pause_enabled = pause_enabled self.__server_api = server_api self.__load_balanced = load_balanced + self.__credentials = credentials self.__metadata = copy.deepcopy(_METADATA) if appname: self.__metadata['application'] = {'name': appname} @@ -325,6 +328,11 @@ class PoolOptions(object): self.__metadata['platform'] = "%s|%s" % ( _METADATA['platform'], driver.platform) + @property + def _credentials(self): + """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" + return self.__credentials + @property def non_default_options(self): """The non-default options this pool was created with. @@ -457,25 +465,6 @@ class PoolOptions(object): return self.__load_balanced -def _negotiate_creds(all_credentials): - """Return one credential that needs mechanism negotiation, if any. - """ - if all_credentials: - for creds in all_credentials.values(): - if creds.mechanism == 'DEFAULT' and creds.username: - return creds - return None - - -def _speculative_context(all_credentials): - """Return the _AuthContext to use for speculative auth, if any. - """ - if all_credentials and len(all_credentials) == 1: - creds = next(iter(all_credentials.values())) - return auth._AuthContext.from_credentials(creds) - return None - - class _CancellationContext(object): def __init__(self): self._cancelled = False @@ -504,7 +493,7 @@ class SocketInfo(object): self.sock = sock self.address = address self.id = id - self.authset = set() + self.authed = set() self.closed = False self.last_checkin_time = time.monotonic() self.performed_handshake = False @@ -523,9 +512,8 @@ class SocketInfo(object): self.compression_context = None self.socket_checker = SocketChecker() # Support for mechanism negotiation on the initial handshake. - # Maps credential to saslSupportedMechs. - self.negotiated_mechanisms = {} - self.auth_ctx = {} + self.negotiated_mechs = None + self.auth_ctx = None # The pool's generation changes with each reset() so we can close # sockets created before the last reset. @@ -567,11 +555,10 @@ class SocketInfo(object): else: return SON([(HelloCompat.LEGACY_CMD, 1), ('helloOk', True)]) - def hello(self, all_credentials=None): - return self._hello(None, None, None, all_credentials) + def hello(self): + return self._hello(None, None, None) - def _hello(self, cluster_time, topology_version, - heartbeat_frequency, all_credentials): + def _hello(self, cluster_time, topology_version, heartbeat_frequency): cmd = self.hello_cmd() performing_handshake = not self.performed_handshake awaitable = False @@ -594,14 +581,15 @@ class SocketInfo(object): if not performing_handshake and cluster_time is not None: cmd['$clusterTime'] = cluster_time - # XXX: Simplify in PyMongo 4.0 when all_credentials is always a single - # unchangeable value per MongoClient. - creds = _negotiate_creds(all_credentials) + creds = self.opts._credentials if creds: - cmd['saslSupportedMechs'] = creds.source + '.' + creds.username - auth_ctx = _speculative_context(all_credentials) - if auth_ctx: - cmd['speculativeAuthenticate'] = auth_ctx.speculate_command() + if creds.mechanism == 'DEFAULT' and creds.username: + cmd['saslSupportedMechs'] = creds.source + '.' + creds.username + auth_ctx = auth._AuthContext.from_credentials(creds) + if auth_ctx: + cmd['speculativeAuthenticate'] = auth_ctx.speculate_command() + else: + auth_ctx = None doc = self.command('admin', cmd, publish_events=False, exhaust_allowed=awaitable) @@ -628,11 +616,11 @@ class SocketInfo(object): self.op_msg_enabled = True if creds: - self.negotiated_mechanisms[creds] = hello.sasl_supported_mechs + self.negotiated_mechs = hello.sasl_supported_mechs if auth_ctx: auth_ctx.parse_response(hello) if auth_ctx.speculate_succeeded(): - self.auth_ctx[auth_ctx.credentials] = auth_ctx + self.auth_ctx = auth_ctx if self.opts.load_balanced: if not hello.service_id: raise ConfigurationError( @@ -799,41 +787,21 @@ class SocketInfo(object): helpers._check_command_response(result, self.max_wire_version) return result - def check_auth(self, all_credentials): - """Update this socket's authentication. + def authenticate(self): + """Authenticate to the server if needed. - Log in or out to bring this socket's credentials up to date with - those provided. Can raise ConnectionFailure or OperationFailure. - - :Parameters: - - `all_credentials`: dict, maps auth source to MongoCredential. + Can raise ConnectionFailure or OperationFailure. """ - if all_credentials: - for credentials in all_credentials.values(): - if credentials not in self.authset: - self.authenticate(credentials) - # CMAP spec says to publish the ready event only after authenticating # the connection. if not self.ready: + creds = self.opts._credentials + if creds: + auth.authenticate(creds, self) self.ready = True if self.enabled_for_cmap: self.listeners.publish_connection_ready(self.address, self.id) - def authenticate(self, credentials): - """Log in to the server and store these credentials in `authset`. - - Can raise ConnectionFailure or OperationFailure. - - :Parameters: - - `credentials`: A MongoCredential. - """ - auth.authenticate(credentials, self) - self.authset.add(credentials) - # negotiated_mechanisms are no longer needed. - self.negotiated_mechanisms.pop(credentials, None) - self.auth_ctx.pop(credentials, None) - def validate_session(self, client, session): """Validate this session before use with client. @@ -1245,7 +1213,7 @@ class Pool: def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) - def remove_stale_sockets(self, reference_generation, all_credentials): + def remove_stale_sockets(self, reference_generation): """Removes stale sockets then adds new ones if pool is too small and has not been reset. The `reference_generation` argument specifies the `generation` at the point in time this operation was requested on the @@ -1281,7 +1249,7 @@ class Pool: return self._pending += 1 incremented = True - sock_info = self.connect(all_credentials) + sock_info = self.connect() with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. @@ -1300,7 +1268,7 @@ class Pool: self.requests -= 1 self.size_cond.notify() - def connect(self, all_credentials=None): + def connect(self): """Connect to Mongo and return a new SocketInfo. Can raise ConnectionFailure. @@ -1331,10 +1299,10 @@ class Pool: sock_info = SocketInfo(sock, self, self.address, conn_id) try: if self.handshake: - sock_info.hello(all_credentials) + sock_info.hello() self.is_writable = sock_info.is_writable - sock_info.check_auth(all_credentials) + sock_info.authenticate() except BaseException: sock_info.close_socket(ConnectionClosedReason.ERROR) raise @@ -1342,7 +1310,7 @@ class Pool: return sock_info @contextlib.contextmanager - def get_socket(self, all_credentials, handler=None): + def get_socket(self, handler=None): """Get a socket from the pool. Use with a "with" statement. Returns a :class:`SocketInfo` object wrapping a connected @@ -1350,25 +1318,20 @@ class Pool: This method should always be used in a with-statement:: - with pool.get_socket(credentials) as socket_info: + with pool.get_socket() as socket_info: socket_info.send_message(msg) data = socket_info.receive_message(op_code, request_id) - The socket is logged in or out as needed to match ``all_credentials`` - using the correct authentication mechanism for the server's wire - protocol version. - Can raise ConnectionFailure or OperationFailure. :Parameters: - - `all_credentials`: dict, maps auth source to MongoCredential. - `handler` (optional): A _MongoClientErrorHandler. """ listeners = self.opts._event_listeners if self.enabled_for_cmap: listeners.publish_connection_check_out_started(self.address) - sock_info = self._get_socket(all_credentials) + sock_info = self._get_socket() if self.enabled_for_cmap: listeners.publish_connection_checked_out( self.address, sock_info.id) @@ -1407,7 +1370,7 @@ class Pool: _raise_connection_failure( self.address, AutoReconnect('connection pool paused')) - def _get_socket(self, all_credentials): + def _get_socket(self): """Get or create a SocketInfo. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of @@ -1480,12 +1443,11 @@ class Pool: continue else: # We need to create a new connection try: - sock_info = self.connect(all_credentials) + sock_info = self.connect() finally: with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() - sock_info.check_auth(all_credentials) except BaseException: if sock_info: # We checked out a socket but authentication failed. diff --git a/pymongo/server.py b/pymongo/server.py index 0a487e8c4..2a0a7267b 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -77,9 +77,9 @@ class Server(object): Can raise ConnectionFailure, OperationFailure, etc. :Parameters: + - `sock_info` - A SocketInfo instance. - `operation`: A _Query or _GetMore object. - `set_secondary_okay`: Pass to operation.get_message. - - `all_credentials`: dict, maps auth source to MongoCredential. - `listeners`: Instance of _EventListeners or None. - `unpack_res`: A callable that decodes the wire protocol response. """ @@ -200,8 +200,8 @@ class Server(object): return response - def get_socket(self, all_credentials, handler=None): - return self.pool.get_socket(all_credentials, handler) + def get_socket(self, handler=None): + return self.pool.get_socket(handler) @property def description(self): diff --git a/pymongo/topology.py b/pymongo/topology.py index 6f26cff61..021a1dee6 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -444,7 +444,7 @@ class Topology(object): return self._description.known_servers return self._description.readable_servers - def update_pool(self, all_credentials): + def update_pool(self): # Remove any stale sockets and add new sockets if pool is too small. servers = [] with self._lock: @@ -456,7 +456,7 @@ class Topology(object): for server, generation in servers: try: - server.pool.remove_stale_sockets(generation, all_credentials) + server.pool.remove_stale_sockets(generation) except PyMongoError as exc: ctx = _ErrorContext(exc, 0, generation, False, None) self.handle_error(server.description.address, ctx) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 8b1ece8ad..1494fbedc 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -40,7 +40,7 @@ class MockPool(Pool): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager - def get_socket(self, all_credentials, handler=None): + def get_socket(self, handler=None): client = self.client host_and_port = '%s:%s' % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: @@ -51,7 +51,7 @@ class MockPool(Pool): + client.mock_members + client.mock_mongoses), "bad host: %s" % host_and_port - with Pool.get_socket(self, all_credentials, handler) as sock_info: + with Pool.get_socket(self, handler) as sock_info: sock_info.mock_host = self.mock_host sock_info.mock_port = self.mock_port yield sock_info diff --git a/test/test_auth.py b/test/test_auth.py index d0724dce7..35f198574 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -30,6 +30,7 @@ from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP from test import client_context, IntegrationTest, SkipTest, unittest, Version from test.utils import (delay, + get_pool, ignore_deprecations, single_client, rs_or_single_client, @@ -521,10 +522,12 @@ class TestSCRAM(IntegrationTest): def test_cache(self): client = single_client() + credentials = client.options.pool_options._credentials + cache = credentials.cache + self.assertIsNotNone(cache) + self.assertIsNone(cache.data) # Force authentication. client.admin.command('ping') - all_credentials = client._MongoClient__all_credentials - credentials = all_credentials.get('admin') cache = credentials.cache self.assertIsNotNone(cache) data = cache.data @@ -536,19 +539,6 @@ class TestSCRAM(IntegrationTest): self.assertIsInstance(salt, bytes) self.assertIsInstance(iterations, int) - pool = next(iter(client._topology._servers.values()))._pool - with pool.get_socket(all_credentials) as sock_info: - authset = sock_info.authset - cached = set(all_credentials.values()) - self.assertEqual(len(cached), 1) - self.assertFalse(authset - cached) - self.assertFalse(cached - authset) - - sock_credentials = next(iter(authset)) - sock_cache = sock_credentials.cache - self.assertIsNotNone(sock_cache) - self.assertEqual(sock_cache.data, data) - def test_scram_threaded(self): coll = client_context.client.db.test diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 8bf0dcb21..e78b4b209 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -44,7 +44,7 @@ def create_test(test_case): self.assertRaises(Exception, MongoClient, uri, connect=False) else: client = MongoClient(uri, connect=False) - credentials = client._MongoClient__options._credentials + credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) else: diff --git a/test/test_client.py b/test/test_client.py index 8c89a4548..8db1cb562 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -498,7 +498,7 @@ class TestClient(IntegrationTest): client = rs_or_single_client() server = client._get_topology().select_server( readable_server_selector) - with server._pool.get_socket({}) as sock_info: + with server._pool.get_socket() as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) self.assertTrue(sock_info in server._pool.sockets) @@ -511,7 +511,7 @@ class TestClient(IntegrationTest): minPoolSize=1) server = client._get_topology().select_server( readable_server_selector) - with server._pool.get_socket({}) as sock_info: + with server._pool.get_socket() as sock_info: pass # When the reaper runs at the same time as the get_socket, two # sockets could be created and checked into the pool. @@ -530,7 +530,7 @@ class TestClient(IntegrationTest): maxPoolSize=1) server = client._get_topology().select_server( readable_server_selector) - with server._pool.get_socket({}) as sock_info: + with server._pool.get_socket() as sock_info: pass # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two sockets from being created. @@ -547,11 +547,11 @@ class TestClient(IntegrationTest): client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server( readable_server_selector) - with server._pool.get_socket({}) as sock_info_one: + with server._pool.get_socket() as sock_info_one: pass # Assert that the pool does not close sockets prematurely. time.sleep(.300) - with server._pool.get_socket({}) as sock_info_two: + with server._pool.get_socket() as sock_info_two: pass self.assertIs(sock_info_one, sock_info_two) wait_until( @@ -574,7 +574,7 @@ class TestClient(IntegrationTest): "pool initialized with 10 sockets") # Assert that if a socket is closed, a new one takes its place - with server._pool.get_socket({}) as sock_info: + with server._pool.get_socket() as sock_info: sock_info.close_socket(None) wait_until(lambda: 10 == len(server._pool.sockets), "a closed socket gets replaced from the pool") @@ -586,12 +586,12 @@ class TestClient(IntegrationTest): client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server( readable_server_selector) - with server._pool.get_socket({}) as sock_info: + with server._pool.get_socket() as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) time.sleep(1) # Sleep so that the socket becomes stale. - with server._pool.get_socket({}) as new_sock_info: + with server._pool.get_socket() as new_sock_info: self.assertNotEqual(sock_info, new_sock_info) self.assertEqual(1, len(server._pool.sockets)) self.assertFalse(sock_info in server._pool.sockets) @@ -601,11 +601,11 @@ class TestClient(IntegrationTest): client = rs_or_single_client() server = client._get_topology().select_server( readable_server_selector) - with server._pool.get_socket({}) as sock_info: + with server._pool.get_socket() as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) time.sleep(1) - with server._pool.get_socket({}) as new_sock_info: + with server._pool.get_socket() as new_sock_info: self.assertEqual(sock_info, new_sock_info) self.assertEqual(1, len(server._pool.sockets)) @@ -1106,7 +1106,7 @@ class TestClient(IntegrationTest): def test_socketKeepAlive(self): pool = get_pool(self.client) - with pool.get_socket({}) as sock_info: + with pool.get_socket() as sock_info: keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) @@ -1325,8 +1325,8 @@ class TestClient(IntegrationTest): socket_info = one(pool.sockets) socket_info.sock.close() - # SocketInfo.check_auth logs in with the new credential, but gets a - # socket.error. Should be reraised as AutoReconnect. + # SocketInfo.authenticate logs, but gets a socket.error. Should be + # reraised as AutoReconnect. self.assertRaises(AutoReconnect, c.test.collection.find_one) # No semaphore leak, the pool is allowed to make a new socket. @@ -1521,8 +1521,7 @@ class TestClient(IntegrationTest): try: while True: for _ in range(10): - client._topology.update_pool( - client._MongoClient__all_credentials) + client._topology.update_pool() if generation != pool.gen.get_overall(): break finally: diff --git a/test/test_cmap.py b/test/test_cmap.py index d08cc24a5..20ed7f31e 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -120,7 +120,7 @@ class TestCMAP(IntegrationTest): def check_out(self, op): """Run the 'checkOut' operation.""" label = op['label'] - with self.pool.get_socket({}) as sock_info: + with self.pool.get_socket() as sock_info: # Call 'pin_cursor' so we can hold the socket. sock_info.pin_cursor() if label: @@ -452,7 +452,7 @@ class TestCMAP(IntegrationTest): self.assertEqual(1, listener.event_count(PoolClearedEvent)) self.assertEqual(PoolState.READY, pool.state) # Checking out a connection should succeed - with pool.get_socket({}): + with pool.get_socket(): pass diff --git a/test/test_pooling.py b/test/test_pooling.py index b8f3cf190..4f0ac3584 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -118,7 +118,7 @@ class SocketGetter(MongoThread): self.state = 'get_socket' # Call 'pin_cursor' so we can hold the socket. - with self.pool.get_socket({}) as sock: + with self.pool.get_socket() as sock: sock.pin_cursor() self.sock = sock @@ -196,10 +196,10 @@ class TestPooling(_TestPoolingBase): # Test Pool's _check_closed() method doesn't close a healthy socket. cx_pool = self.create_pool(max_pool_size=10) cx_pool._check_interval_seconds = 0 # Always check. - with cx_pool.get_socket({}) as sock_info: + with cx_pool.get_socket() as sock_info: pass - with cx_pool.get_socket({}) as new_sock_info: + with cx_pool.get_socket() as new_sock_info: self.assertEqual(sock_info, new_sock_info) self.assertEqual(1, len(cx_pool.sockets)) @@ -208,11 +208,11 @@ class TestPooling(_TestPoolingBase): # get_socket() returns socket after a non-network error. cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) with self.assertRaises(ZeroDivisionError): - with cx_pool.get_socket({}) as sock_info: + with cx_pool.get_socket() as sock_info: 1 / 0 # Socket was returned, not closed. - with cx_pool.get_socket({}) as new_sock_info: + with cx_pool.get_socket() as new_sock_info: self.assertEqual(sock_info, new_sock_info) self.assertEqual(1, len(cx_pool.sockets)) @@ -221,7 +221,7 @@ class TestPooling(_TestPoolingBase): # Test that Pool removes explicitly closed socket. cx_pool = self.create_pool() - with cx_pool.get_socket({}) as sock_info: + with cx_pool.get_socket() as sock_info: # Use SocketInfo's API to close the socket. sock_info.close_socket(None) @@ -233,20 +233,20 @@ class TestPooling(_TestPoolingBase): cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) cx_pool._check_interval_seconds = 0 # Always check. - with cx_pool.get_socket({}) as sock_info: + with cx_pool.get_socket() as sock_info: # Simulate a closed socket without telling the SocketInfo it's # closed. sock_info.sock.close() self.assertTrue(sock_info.socket_closed()) - with cx_pool.get_socket({}) as new_sock_info: + with cx_pool.get_socket() as new_sock_info: self.assertEqual(0, len(cx_pool.sockets)) self.assertNotEqual(sock_info, new_sock_info) self.assertEqual(1, len(cx_pool.sockets)) # Semaphore was released. - with cx_pool.get_socket({}): + with cx_pool.get_socket(): pass def test_socket_closed(self): @@ -290,7 +290,7 @@ class TestPooling(_TestPoolingBase): def test_return_socket_after_reset(self): pool = self.create_pool() - with pool.get_socket({}) as sock: + with pool.get_socket() as sock: self.assertEqual(pool.active_sockets, 1) self.assertEqual(pool.operation_count, 1) pool.reset() @@ -309,7 +309,7 @@ class TestPooling(_TestPoolingBase): cx_pool._check_interval_seconds = 0 # Always check. self.addCleanup(cx_pool.close) - with cx_pool.get_socket({}) as sock_info: + with cx_pool.get_socket() as sock_info: # Simulate a closed socket without telling the SocketInfo it's # closed. sock_info.sock.close() @@ -317,12 +317,12 @@ class TestPooling(_TestPoolingBase): # Swap pool's address with a bad one. address, cx_pool.address = cx_pool.address, ('foo.com', 1234) with self.assertRaises(AutoReconnect): - with cx_pool.get_socket({}): + with cx_pool.get_socket(): pass # Back to normal, semaphore was correctly released. cx_pool.address = address - with cx_pool.get_socket({}): + with cx_pool.get_socket(): pass def test_wait_queue_timeout(self): @@ -331,10 +331,10 @@ class TestPooling(_TestPoolingBase): max_pool_size=1, wait_queue_timeout=wait_queue_timeout) self.addCleanup(pool.close) - with pool.get_socket({}) as sock_info: + with pool.get_socket() as sock_info: start = time.time() with self.assertRaises(ConnectionFailure): - with pool.get_socket({}): + with pool.get_socket(): pass duration = time.time() - start @@ -349,7 +349,7 @@ class TestPooling(_TestPoolingBase): self.addCleanup(pool.close) # Reach max_size. - with pool.get_socket({}) as s1: + with pool.get_socket() as s1: t = SocketGetter(self.c, pool) t.start() while t.state != 'get_socket': @@ -370,7 +370,7 @@ class TestPooling(_TestPoolingBase): socks = [] for _ in range(2): # Call 'pin_cursor' so we can hold the socket. - with pool.get_socket({}) as sock: + with pool.get_socket() as sock: sock.pin_cursor() socks.append(sock) @@ -515,7 +515,7 @@ class TestPoolMaxSize(_TestPoolingBase): # socket from pool" instead of AutoReconnect. for i in range(2): with self.assertRaises(AutoReconnect) as context: - with test_pool.get_socket({}): + with test_pool.get_socket(): pass # Testing for AutoReconnect instead of ConnectionFailure, above, diff --git a/test/utils.py b/test/utils.py index bdea5c69c..efc6e2487 100644 --- a/test/utils.py +++ b/test/utils.py @@ -269,7 +269,7 @@ class MockPool(object): def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) - def get_socket(self, all_credentials, handler=None): + def get_socket(self, handler=None): return MockSocketInfo() def return_socket(self, *args, **kwargs):