Fix asyncio transport socket leak
This commit is contained in:
parent
eceb69faa7
commit
b01931495a
@ -312,15 +312,14 @@ async def _configured_protocol_interface(
|
||||
Sets protocol's SSL and timeout options.
|
||||
"""
|
||||
sock = await _async_create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
timeout = options.socket_timeout
|
||||
# Create the Protocol early to prevent asyncio resource leaks during cleanup path
|
||||
protocol = PyMongoProtocol(timeout=timeout)
|
||||
sock_adopted = False
|
||||
try:
|
||||
ssl_context = options._ssl_context
|
||||
timeout = options.socket_timeout
|
||||
|
||||
if ssl_context is None:
|
||||
result = await asyncio.get_running_loop().create_connection(
|
||||
lambda: PyMongoProtocol(timeout=timeout), sock=sock
|
||||
)
|
||||
result = await asyncio.get_running_loop().create_connection(lambda: protocol, sock=sock)
|
||||
sock_adopted = True
|
||||
return AsyncNetworkingInterface(result)
|
||||
|
||||
@ -328,8 +327,8 @@ async def _configured_protocol_interface(
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
|
||||
lambda: PyMongoProtocol(timeout=timeout),
|
||||
transport, _ = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
|
||||
lambda: protocol,
|
||||
sock=sock,
|
||||
server_hostname=host,
|
||||
ssl=ssl_context,
|
||||
@ -360,12 +359,13 @@ async def _configured_protocol_interface(
|
||||
|
||||
return AsyncNetworkingInterface((transport, protocol))
|
||||
finally:
|
||||
# If cancellation or any exception lands between socket creation and
|
||||
# transport adoption, asyncio.create_connection has not registered
|
||||
# cleanup for the sock.
|
||||
# Close it ourselves to prevent leaks.
|
||||
if not sock_adopted:
|
||||
sock.close()
|
||||
# If the protocol owns the transport, it also adopted the socket and needs to be cleaned up from the transport
|
||||
if protocol.transport is not None:
|
||||
protocol.transport.abort()
|
||||
# Otherwise the socket was never adopted, close it directly
|
||||
else:
|
||||
sock.close()
|
||||
|
||||
|
||||
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user