diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index 1ad84478d..2e4522ea3 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -312,15 +312,14 @@ async def _configured_protocol_interface( Sets protocol's SSL and timeout options. """ sock = await _async_create_connection(address, options) + ssl_context = options._ssl_context + timeout = options.socket_timeout + # Create the Protocol early to prevent asyncio resource leaks during cleanup path + protocol = PyMongoProtocol(timeout=timeout) sock_adopted = False try: - ssl_context = options._ssl_context - timeout = options.socket_timeout - if ssl_context is None: - result = await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock - ) + result = await asyncio.get_running_loop().create_connection(lambda: protocol, sock=sock) sock_adopted = True return AsyncNetworkingInterface(result) @@ -328,8 +327,8 @@ async def _configured_protocol_interface( try: # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. - transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] - lambda: PyMongoProtocol(timeout=timeout), + transport, _ = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] + lambda: protocol, sock=sock, server_hostname=host, ssl=ssl_context, @@ -360,12 +359,13 @@ async def _configured_protocol_interface( return AsyncNetworkingInterface((transport, protocol)) finally: - # If cancellation or any exception lands between socket creation and - # transport adoption, asyncio.create_connection has not registered - # cleanup for the sock. - # Close it ourselves to prevent leaks. if not sock_adopted: - sock.close() + # If the protocol owns the transport, it also adopted the socket and needs to be cleaned up from the transport + if protocol.transport is not None: + protocol.transport.abort() + # Otherwise the socket was never adopted, close it directly + else: + sock.close() def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: