This commit is contained in:
Noah Stapp 2026-05-18 18:17:44 +00:00 committed by GitHub
commit 6c75afb5ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 131 additions and 95 deletions

View File

@ -1065,35 +1065,44 @@ class Pool:
raise
conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)
return conn
# Catch cancellations that interrupt outside the inner try block above
except BaseException:
if not conn.closed:
try:
await conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)
return conn
@contextlib.asynccontextmanager
async def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None

View File

@ -207,6 +207,7 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
sock_returned = False
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
@ -223,14 +224,18 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
asyncio.get_running_loop().sock_connect(sock, sa), timeout=timeout
)
sock.settimeout(timeout)
# Set immediately before return. Do not insert an await between this and the return
sock_returned = True
return sock
except asyncio.TimeoutError as e:
sock.close()
err = socket.timeout("timed out")
err.__cause__ = e
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
finally:
# Always close the socket if it wasn't returned to avoid leaks.
if not sock_returned:
sock.close()
if err is not None:
raise err
@ -309,46 +314,58 @@ async def _configured_protocol_interface(
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
timeout = options.socket_timeout
if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)
host = address[0]
# Create the Protocol early to prevent asyncio resource leaks during cleanup path
protocol = PyMongoProtocol(timeout=timeout)
sock_adopted = False
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
except _CertificateError:
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise
if ssl_context is None:
result = await asyncio.get_running_loop().create_connection(lambda: protocol, sock=sock)
sock_adopted = True
return AsyncNetworkingInterface(result)
return AsyncNetworkingInterface((transport, protocol))
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, _ = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: protocol,
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
sock_adopted = True
except _CertificateError:
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(
address, exc, "SSL handshake failed: ", timeout_details=details
)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise
return AsyncNetworkingInterface((transport, protocol))
finally:
if not sock_adopted:
# If the protocol owns the transport, it also adopted the socket and needs to be cleaned up from the transport
if protocol.transport is not None:
protocol.transport.abort()
# Otherwise the socket was never adopted, close it directly
else:
sock.close()
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:

View File

@ -1061,35 +1061,44 @@ class Pool:
raise
conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise
if handler:
handler.client._topology.receive_cluster_time(conn._cluster_time)
return conn
# Catch cancellations that interrupt outside the inner try block above
except BaseException:
if not conn.closed:
try:
conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
if handler:
handler.client._topology.receive_cluster_time(conn._cluster_time)
return conn
@contextlib.contextmanager
def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None

View File

@ -108,11 +108,6 @@ filterwarnings = [
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/1032
"module:.*WindowsSelectorEventLoopPolicy:DeprecationWarning",
"module:.*et_event_loop_policy:DeprecationWarning",
# TODO: Remove as part of PYTHON-3923.
"module:unclosed <socket.socket:ResourceWarning",
"module:unclosed <ssl.SSLSocket:ResourceWarning",
"module:unclosed <socket object:ResourceWarning",
"module:unclosed transport:ResourceWarning",
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/724
"module:unclosed event loop:ResourceWarning",
# https://github.com/dateutil/dateutil/issues/1314

View File

@ -1217,6 +1217,8 @@ def setup():
def teardown():
global_knobs.disable()
if client_context.client is not None:
client_context.client.close()
garbage = []
for g in gc.garbage:
garbage.append(f"GARBAGE: {g!r}")

View File

@ -1233,6 +1233,8 @@ async def async_setup():
async def async_teardown():
global_knobs.disable()
if async_client_context.client is not None:
await async_client_context.client.close()
garbage = []
for g in gc.garbage:
garbage.append(f"GARBAGE: {g!r}")

View File

@ -172,6 +172,7 @@ class _TestPoolingBase(AsyncIntegrationTest):
kwargs["server_api"] = pool_options.server_api
pool = Pool(pair, PoolOptions(*args, **kwargs))
await pool.ready()
self.addAsyncCleanup(pool.close)
return pool

View File

@ -172,6 +172,7 @@ class _TestPoolingBase(IntegrationTest):
kwargs["server_api"] = pool_options.server_api
pool = Pool(pair, PoolOptions(*args, **kwargs))
pool.ready()
self.addCleanup(pool.close)
return pool