Merge b01931495a into 9a8e34c726
This commit is contained in:
commit
6c75afb5ba
@ -1065,35 +1065,44 @@ class Pool:
|
||||
raise
|
||||
|
||||
conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
|
||||
async with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
self.active_contexts.discard(tmp_context)
|
||||
if tmp_context.cancelled:
|
||||
conn.cancel_context.cancel()
|
||||
completed_hello = False
|
||||
try:
|
||||
if not self.is_sdam:
|
||||
await conn.hello()
|
||||
completed_hello = True
|
||||
self.is_writable = conn.is_writable
|
||||
if handler:
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
await conn.authenticate()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as e:
|
||||
async with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
if not completed_hello:
|
||||
self._handle_connection_error(e)
|
||||
await conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
self.active_contexts.discard(tmp_context)
|
||||
if tmp_context.cancelled:
|
||||
conn.cancel_context.cancel()
|
||||
completed_hello = False
|
||||
try:
|
||||
if not self.is_sdam:
|
||||
await conn.hello()
|
||||
completed_hello = True
|
||||
self.is_writable = conn.is_writable
|
||||
if handler:
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
await conn.authenticate()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as e:
|
||||
async with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
if not completed_hello:
|
||||
self._handle_connection_error(e)
|
||||
await conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
raise
|
||||
|
||||
if handler:
|
||||
await handler.client._topology.receive_cluster_time(conn._cluster_time)
|
||||
|
||||
return conn
|
||||
# Catch cancellations that interrupt outside the inner try block above
|
||||
except BaseException:
|
||||
if not conn.closed:
|
||||
try:
|
||||
await conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
except BaseException: # noqa: S110
|
||||
pass
|
||||
raise
|
||||
|
||||
if handler:
|
||||
await handler.client._topology.receive_cluster_time(conn._cluster_time)
|
||||
|
||||
return conn
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def checkout(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
|
||||
@ -207,6 +207,7 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
# Fallback when SOCK_CLOEXEC isn't available.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
sock_returned = False
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# CSOT: apply timeout to socket connect.
|
||||
@ -223,14 +224,18 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
|
||||
asyncio.get_running_loop().sock_connect(sock, sa), timeout=timeout
|
||||
)
|
||||
sock.settimeout(timeout)
|
||||
# Set immediately before return. Do not insert an await between this and the return
|
||||
sock_returned = True
|
||||
return sock
|
||||
except asyncio.TimeoutError as e:
|
||||
sock.close()
|
||||
err = socket.timeout("timed out")
|
||||
err.__cause__ = e
|
||||
except OSError as e:
|
||||
sock.close()
|
||||
err = e # type: ignore[assignment]
|
||||
finally:
|
||||
# Always close the socket if it wasn't returned to avoid leaks.
|
||||
if not sock_returned:
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
raise err
|
||||
@ -309,46 +314,58 @@ async def _configured_protocol_interface(
|
||||
sock = await _async_create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
timeout = options.socket_timeout
|
||||
|
||||
if ssl_context is None:
|
||||
return AsyncNetworkingInterface(
|
||||
await asyncio.get_running_loop().create_connection(
|
||||
lambda: PyMongoProtocol(timeout=timeout), sock=sock
|
||||
)
|
||||
)
|
||||
|
||||
host = address[0]
|
||||
# Create the Protocol early to prevent asyncio resource leaks during cleanup path
|
||||
protocol = PyMongoProtocol(timeout=timeout)
|
||||
sock_adopted = False
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
|
||||
lambda: PyMongoProtocol(timeout=timeout),
|
||||
sock=sock,
|
||||
server_hostname=host,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
except _CertificateError:
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
|
||||
except _CertificateError:
|
||||
transport.abort()
|
||||
raise
|
||||
if ssl_context is None:
|
||||
result = await asyncio.get_running_loop().create_connection(lambda: protocol, sock=sock)
|
||||
sock_adopted = True
|
||||
return AsyncNetworkingInterface(result)
|
||||
|
||||
return AsyncNetworkingInterface((transport, protocol))
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
transport, _ = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
|
||||
lambda: protocol,
|
||||
sock=sock,
|
||||
server_hostname=host,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
sock_adopted = True
|
||||
except _CertificateError:
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(
|
||||
address, exc, "SSL handshake failed: ", timeout_details=details
|
||||
)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
|
||||
except _CertificateError:
|
||||
transport.abort()
|
||||
raise
|
||||
|
||||
return AsyncNetworkingInterface((transport, protocol))
|
||||
finally:
|
||||
if not sock_adopted:
|
||||
# If the protocol owns the transport, it also adopted the socket and needs to be cleaned up from the transport
|
||||
if protocol.transport is not None:
|
||||
protocol.transport.abort()
|
||||
# Otherwise the socket was never adopted, close it directly
|
||||
else:
|
||||
sock.close()
|
||||
|
||||
|
||||
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
|
||||
@ -1061,35 +1061,44 @@ class Pool:
|
||||
raise
|
||||
|
||||
conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
|
||||
with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
self.active_contexts.discard(tmp_context)
|
||||
if tmp_context.cancelled:
|
||||
conn.cancel_context.cancel()
|
||||
completed_hello = False
|
||||
try:
|
||||
if not self.is_sdam:
|
||||
conn.hello()
|
||||
completed_hello = True
|
||||
self.is_writable = conn.is_writable
|
||||
if handler:
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
conn.authenticate()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as e:
|
||||
with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
if not completed_hello:
|
||||
self._handle_connection_error(e)
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
self.active_contexts.discard(tmp_context)
|
||||
if tmp_context.cancelled:
|
||||
conn.cancel_context.cancel()
|
||||
completed_hello = False
|
||||
try:
|
||||
if not self.is_sdam:
|
||||
conn.hello()
|
||||
completed_hello = True
|
||||
self.is_writable = conn.is_writable
|
||||
if handler:
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
conn.authenticate()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as e:
|
||||
with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
if not completed_hello:
|
||||
self._handle_connection_error(e)
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
raise
|
||||
|
||||
if handler:
|
||||
handler.client._topology.receive_cluster_time(conn._cluster_time)
|
||||
|
||||
return conn
|
||||
# Catch cancellations that interrupt outside the inner try block above
|
||||
except BaseException:
|
||||
if not conn.closed:
|
||||
try:
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
except BaseException: # noqa: S110
|
||||
pass
|
||||
raise
|
||||
|
||||
if handler:
|
||||
handler.client._topology.receive_cluster_time(conn._cluster_time)
|
||||
|
||||
return conn
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkout(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
|
||||
@ -108,11 +108,6 @@ filterwarnings = [
|
||||
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/1032
|
||||
"module:.*WindowsSelectorEventLoopPolicy:DeprecationWarning",
|
||||
"module:.*et_event_loop_policy:DeprecationWarning",
|
||||
# TODO: Remove as part of PYTHON-3923.
|
||||
"module:unclosed <socket.socket:ResourceWarning",
|
||||
"module:unclosed <ssl.SSLSocket:ResourceWarning",
|
||||
"module:unclosed <socket object:ResourceWarning",
|
||||
"module:unclosed transport:ResourceWarning",
|
||||
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/724
|
||||
"module:unclosed event loop:ResourceWarning",
|
||||
# https://github.com/dateutil/dateutil/issues/1314
|
||||
|
||||
@ -1217,6 +1217,8 @@ def setup():
|
||||
|
||||
def teardown():
|
||||
global_knobs.disable()
|
||||
if client_context.client is not None:
|
||||
client_context.client.close()
|
||||
garbage = []
|
||||
for g in gc.garbage:
|
||||
garbage.append(f"GARBAGE: {g!r}")
|
||||
|
||||
@ -1233,6 +1233,8 @@ async def async_setup():
|
||||
|
||||
async def async_teardown():
|
||||
global_knobs.disable()
|
||||
if async_client_context.client is not None:
|
||||
await async_client_context.client.close()
|
||||
garbage = []
|
||||
for g in gc.garbage:
|
||||
garbage.append(f"GARBAGE: {g!r}")
|
||||
|
||||
@ -172,6 +172,7 @@ class _TestPoolingBase(AsyncIntegrationTest):
|
||||
kwargs["server_api"] = pool_options.server_api
|
||||
pool = Pool(pair, PoolOptions(*args, **kwargs))
|
||||
await pool.ready()
|
||||
self.addAsyncCleanup(pool.close)
|
||||
return pool
|
||||
|
||||
|
||||
|
||||
@ -172,6 +172,7 @@ class _TestPoolingBase(IntegrationTest):
|
||||
kwargs["server_api"] = pool_options.server_api
|
||||
pool = Pool(pair, PoolOptions(*args, **kwargs))
|
||||
pool.ready()
|
||||
self.addCleanup(pool.close)
|
||||
return pool
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user