PYTHON-4636 - Avoid blocking I/O calls in async code paths (#1870)

Co-authored-by: Shane Harvey <shnhrv@gmail.com>
This commit is contained in:
Noah Stapp 2024-10-03 15:18:33 -04:00 committed by GitHub
parent 7380097dbc
commit b111cbf5d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 248 additions and 166 deletions

View File

@ -15,11 +15,8 @@
"""Internal network layer helper methods."""
from __future__ import annotations
import asyncio
import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
@ -40,19 +37,16 @@ from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
async_receive_data,
async_sendall,
)
from pymongo.socket_checker import _errno_from_exception
if TYPE_CHECKING:
from bson import CodecOptions
@ -318,9 +312,7 @@ async def receive_message(
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
await _receive_data_on_socket(conn, 16, deadline)
)
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
@ -336,11 +328,11 @@ async def receive_message(
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await _receive_data_on_socket(conn, 9, deadline)
await async_receive_data(conn, 9, deadline)
)
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
else:
data = await _receive_data_on_socket(conn, length - 16, deadline)
data = await async_receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
@ -349,66 +341,3 @@ async def receive_message(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
    """Wait until the connection's socket has data available to read.

    Returns immediately when the socket's ``fileno()`` is -1. Otherwise polls
    in a loop, yielding to the event loop between polls.

    :param conn: connection whose ``conn`` attribute is the socket to watch.
    :param deadline: ``time.monotonic()`` based deadline, or a falsy value for
        no deadline.
    :raises _OperationCancelled: if ``conn.cancel_context.cancelled`` is set
        after a poll.
    :raises socket.timeout: if the deadline had already passed when the last
        poll started and that poll found nothing readable.
    """
    raw_sock = conn.conn
    # Check if the connection's socket has been manually closed
    if raw_sock.fileno() == -1:
        return
    deadline_passed = False
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        has_buffered = hasattr(raw_sock, "pending") and raw_sock.pending() > 0
        if has_buffered:
            ready = True
        else:
            # Wait up to _POLL_TIMEOUT for the socket to become readable and
            # then check for cancellation.
            wait_for = _POLL_TIMEOUT
            if deadline:
                time_left = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                deadline_passed = deadline_passed or time_left <= 0
                wait_for = max(min(time_left, _POLL_TIMEOUT), 0)
            ready = conn.socket_checker.select(raw_sock, read=True, timeout=wait_for)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if ready:
            return
        if deadline_passed:
            raise socket.timeout("timed out")
        # Give other tasks a chance to run before polling again.
        await asyncio.sleep(0)
async def _receive_data_on_socket(
    conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
    """Read exactly ``length`` bytes from ``conn.conn`` into a new buffer.

    :param conn: connection to read from.
    :param length: number of bytes to read.
    :param deadline: ``time.monotonic()`` based deadline forwarded to
        ``wait_for_read``, or None.
    :returns: a memoryview over the bytes that were read.
    :raises socket.timeout: if the read raises one of ``BLOCKING_IO_ERRORS``.
    :raises OSError: "connection closed" if ``recv_into`` returns 0; other
        OS errors are re-raised unless their errno is EINTR.
    """
    view = memoryview(bytearray(length))
    received = 0
    while received < length:
        try:
            await wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            count = conn.conn.recv_into(view[received:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as err:
            # An interrupted call is retried; everything else propagates.
            if _errno_from_exception(err) != errno.EINTR:
                raise
            continue
        if count == 0:
            raise OSError("connection closed")
        received += count
    return view

View File

@ -16,15 +16,21 @@
from __future__ import annotations
import asyncio
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, Future
from typing import (
TYPE_CHECKING,
Optional,
Union,
)
from pymongo import ssl_support
from pymongo import _csot, ssl_support
from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception
try:
from ssl import SSLError, SSLSocket
@ -51,6 +57,10 @@ except ImportError:
BLOCKING_IO_WRITE_ERROR,
)
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.synchronous.pool import Connection
_UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
@ -80,12 +90,9 @@ if sys.platform != "win32":
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
) -> None:
view = memoryview(buf)
fd = sock.fileno()
sent = 0
def _is_ready(fut: Future) -> None:
loop.remove_writer(fd)
loop.remove_reader(fd)
if fut.done():
return
fut.set_result(None)
@ -101,33 +108,240 @@ if sys.platform != "win32":
if isinstance(exc, BLOCKING_IO_READ_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
await fut
try:
await fut
finally:
loop.remove_reader(fd)
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
fut = loop.create_future()
loop.add_writer(fd, _is_ready, fut)
await fut
try:
await fut
finally:
loop.remove_writer(fd)
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
try:
loop.add_writer(fd, _is_ready, fut)
await fut
finally:
loop.remove_reader(fd)
loop.remove_writer(fd)
async def _async_receive_ssl(
    conn: _sslConn, length: int, loop: AbstractEventLoop
) -> memoryview:
    """Read exactly ``length`` bytes from a non-blocking TLS connection.

    ``recv_into`` is attempted directly. When it raises one of
    ``BLOCKING_IO_ERRORS`` the coroutine registers a reader and/or writer
    callback on ``loop`` for the connection's file descriptor, awaits it, and
    then retries the read.

    :param conn: TLS connection to read from.
    :param length: number of bytes to read.
    :param loop: event loop used for ``add_reader``/``add_writer``.
    :returns: a memoryview over the bytes that were read.
    :raises OSError: "connection closed" if ``recv_into`` returns 0.
    :raises SSLError: if the file descriptor is -1 when a blocking error is
        handled.
    """
    mv = memoryview(bytearray(length))
    total_read = 0

    def _is_ready(fut: Future) -> None:
        # The callback may fire after the future was already resolved or
        # cancelled; setting a result twice would raise.
        if fut.done():
            return
        fut.set_result(None)

    while total_read < length:
        try:
            read = conn.recv_into(mv[total_read:])
            if read == 0:
                raise OSError("connection closed")
            total_read += read
        except BLOCKING_IO_ERRORS as exc:
            fd = conn.fileno()
            # Check for closed socket.
            if fd == -1:
                raise SSLError("Underlying socket has been closed") from None
            if isinstance(exc, BLOCKING_IO_READ_ERROR):
                fut = loop.create_future()
                loop.add_reader(fd, _is_ready, fut)
                try:
                    await fut
                finally:
                    loop.remove_reader(fd)
            if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
                fut = loop.create_future()
                loop.add_writer(fd, _is_ready, fut)
                # Every await on the future happens inside try/finally so the
                # writer callback is removed even when this task is cancelled
                # while waiting.
                try:
                    await fut
                finally:
                    loop.remove_writer(fd)
            if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
                fut = loop.create_future()
                loop.add_reader(fd, _is_ready, fut)
                try:
                    loop.add_writer(fd, _is_ready, fut)
                    await fut
                finally:
                    loop.remove_reader(fd)
                    loop.remove_writer(fd)
    return mv
else:
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
async def _async_sendall_ssl(
    sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
) -> None:
    """Send all of ``buf`` on a non-blocking socket by polling with backoff.

    This variant does not use ``loop.add_reader``/``add_writer``; when
    ``send`` raises one of ``BLOCKING_IO_ERRORS`` it sleeps and retries.

    :param sock: socket or TLS connection to write to.
    :param buf: bytes to send.
    :param dummy: unused; accepted so the signature matches the
        event-loop based implementation.
    """
    view = memoryview(buf)
    total_length = len(buf)
    total_sent = 0
    # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
    # down to 1ms.
    backoff = 0.001
    while total_sent < total_length:
        try:
            sent = sock.send(view[total_sent:])
        except BLOCKING_IO_ERRORS:
            # Sleep only for the current backoff interval; an additional fixed
            # half-second sleep here would defeat the adaptive backoff.
            await asyncio.sleep(backoff)
            sent = 0
        if sent > 0:
            backoff = max(backoff / 2, 0.001)
        else:
            backoff = min(backoff * 2, 0.512)
        total_sent += sent
async def _async_receive_ssl(
    conn: _sslConn, length: int, dummy: AbstractEventLoop
) -> memoryview:
    """Read exactly ``length`` bytes from a non-blocking connection by polling.

    When ``recv_into`` raises one of ``BLOCKING_IO_ERRORS`` the coroutine
    sleeps for the current backoff interval and retries.

    :param conn: TLS connection to read from.
    :param length: number of bytes to read.
    :param dummy: unused; accepted so the signature matches the
        event-loop based implementation.
    :returns: a memoryview over the bytes that were read.
    :raises OSError: "connection closed" if ``recv_into`` returns 0.
    """
    buffer = memoryview(bytearray(length))
    filled = 0
    # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
    # down to 1ms.
    delay = 0.001
    while filled < length:
        try:
            got = conn.recv_into(buffer[filled:])
            if got == 0:
                raise OSError("connection closed")
        except BLOCKING_IO_ERRORS:
            await asyncio.sleep(delay)
            got = 0
        delay = max(delay / 2, 0.001) if got > 0 else min(delay * 2, 0.512)
        filled += got
    return buffer
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
    """Send all of ``buf`` by delegating to the socket's own ``sendall``.

    :param sock: socket or TLS connection to write to.
    :param buf: bytes to send.
    """
    send_everything = sock.sendall
    send_everything(buf)
async def _poll_cancellation(conn: AsyncConnection) -> None:
    """Return once ``conn.cancel_context.cancelled`` is set.

    The flag is checked every ``_POLL_TIMEOUT`` seconds, sleeping
    asynchronously in between.
    """
    while not conn.cancel_context.cancelled:
        await asyncio.sleep(_POLL_TIMEOUT)
async def async_receive_data(
    conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` without blocking the event loop.

    The socket is switched to non-blocking mode for the duration of the call
    and its previous timeout is restored afterwards. A read task races against
    a task that polls the connection's cancel context.

    :param conn: connection whose ``conn`` attribute is the socket to read.
    :param length: number of bytes to read.
    :param deadline: ``time.monotonic()`` based deadline; when falsy the
        socket's own timeout is used as the wait timeout instead.
    :returns: a memoryview over the bytes that were read.
    :raises socket.timeout: if neither task finished within the timeout.
    :raises _OperationCancelled: if the cancellation poller finished first.
    """
    sock = conn.conn
    sock_timeout = sock.gettimeout()
    timeout: Optional[Union[float, int]]

    if deadline:
        # When the timeout has expired perform one final check to
        # see if the socket is readable. This helps avoid spurious
        # timeouts on AWS Lambda and other FaaS environments.
        timeout = max(deadline - time.monotonic(), 0)
    else:
        timeout = sock_timeout

    sock.settimeout(0.0)
    loop = asyncio.get_event_loop()
    cancellation_task = asyncio.create_task(_poll_cancellation(conn))
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop))  # type: ignore[arg-type]
        else:
            read_task = asyncio.create_task(_async_receive(sock, length, loop))  # type: ignore[arg-type]
        tasks = [read_task, cancellation_task]
        try:
            done, pending = await asyncio.wait(
                tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
            )
            for task in pending:
                task.cancel()
            # Both tasks can finish in the same loop iteration, leaving
            # nothing pending; asyncio.wait() raises ValueError on an empty
            # collection, so only wait when there is something to wait for.
            if pending:
                await asyncio.wait(pending)
        except asyncio.CancelledError:
            # The caller was cancelled while waiting: stop both child tasks
            # so they do not keep running in the background, then propagate.
            for task in tasks:
                task.cancel()
            await asyncio.wait(tasks)
            raise
        if len(done) == 0:
            raise socket.timeout("timed out")
        if read_task in done:
            return read_task.result()
        raise _OperationCancelled("operation cancelled")
    finally:
        sock.settimeout(sock_timeout)
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` via ``loop.sock_recv_into``.

    :param conn: socket to read from.
    :param length: number of bytes to read.
    :param loop: event loop providing ``sock_recv_into``.
    :returns: a memoryview over the bytes that were read.
    :raises OSError: "connection closed" if a read returns 0 bytes.
    """
    buffer = memoryview(bytearray(length))
    filled = 0
    while filled < length:
        got = await loop.sock_recv_into(conn, buffer[filled:])
        if got == 0:
            raise OSError("connection closed")
        filled += got
    return buffer
# Sync version:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Wait until the connection's socket has data available to read.

    Returns immediately when the socket's ``fileno()`` is -1. Otherwise polls
    in a loop until the socket is readable.

    :param conn: connection whose ``conn`` attribute is the socket to watch.
    :param deadline: ``time.monotonic()`` based deadline, or a falsy value for
        no deadline.
    :raises _OperationCancelled: if ``conn.cancel_context.cancelled`` is set
        after a poll.
    :raises socket.timeout: if the deadline had already passed when the last
        poll started and that poll found nothing readable.
    """
    raw_sock = conn.conn
    # Check if the connection's socket has been manually closed
    if raw_sock.fileno() == -1:
        return
    deadline_passed = False
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        has_buffered = hasattr(raw_sock, "pending") and raw_sock.pending() > 0
        if has_buffered:
            ready = True
        else:
            # Wait up to _POLL_TIMEOUT for the socket to become readable and
            # then check for cancellation.
            wait_for = _POLL_TIMEOUT
            if deadline:
                time_left = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                deadline_passed = deadline_passed or time_left <= 0
                wait_for = max(min(time_left, _POLL_TIMEOUT), 0)
            ready = conn.socket_checker.select(raw_sock, read=True, timeout=wait_for)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if ready:
            return
        if deadline_passed:
            raise socket.timeout("timed out")
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    """Read exactly ``length`` bytes from ``conn.conn`` into a new buffer.

    :param conn: connection to read from.
    :param length: number of bytes to read.
    :param deadline: ``time.monotonic()`` based deadline forwarded to
        ``wait_for_read``, or None.
    :returns: a memoryview over the bytes that were read.
    :raises socket.timeout: if the read raises one of ``BLOCKING_IO_ERRORS``.
    :raises OSError: "connection closed" if ``recv_into`` returns 0; other
        OS errors are re-raised unless their errno is EINTR.
    """
    view = memoryview(bytearray(length))
    received = 0
    while received < length:
        try:
            wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            count = conn.conn.recv_into(view[received:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as err:
            # An interrupted call is retried; everything else propagates.
            if _errno_from_exception(err) != errno.EINTR:
                raise
            continue
        if count == 0:
            raise OSError("connection closed")
        received += count
    return view

View File

@ -105,13 +105,19 @@ def _ragged_eof(exc: BaseException) -> bool:
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
def __init__(
    self,
    ctx: _SSL.Context,
    sock: Optional[_socket.socket],
    suppress_ragged_eofs: bool,
    is_async: bool = False,
):
    """Create the TLS connection wrapper.

    :param ctx: context passed to the base class constructor.
    :param sock: underlying socket passed to the base class constructor.
    :param suppress_ragged_eofs: stored on the instance as
        ``suppress_ragged_eofs``.
    :param is_async: stored as ``_is_async``; defaults to False so existing
        three-argument callers keep working.
    """
    self.socket_checker = _SocketChecker()
    self.suppress_ragged_eofs = suppress_ragged_eofs
    super().__init__(ctx, sock)
    self._is_async = is_async
def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
is_async = kwargs.pop("allow_async", True) and self._is_async
timeout = self.gettimeout()
if timeout:
start = _time.monotonic()
@ -119,6 +125,8 @@ class _sslConn(_SSL.Connection):
try:
return call(*args, **kwargs)
except BLOCKING_IO_ERRORS as exc:
if is_async:
raise exc
# Check for closed socket.
if self.fileno() == -1:
if timeout and _time.monotonic() - start > timeout:
@ -139,6 +147,7 @@ class _sslConn(_SSL.Connection):
continue
def do_handshake(self, *args: Any, **kwargs: Any) -> None:
    """Run the base class handshake through ``_call`` with ``allow_async`` forced to False."""
    # Any caller-supplied allow_async value is overridden.
    kwargs.update(allow_async=False)
    handshake = super().do_handshake
    return self._call(handshake, *args, **kwargs)
def recv(self, *args: Any, **kwargs: Any) -> bytes:
@ -381,7 +390,7 @@ class SSLContext:
"""Wrap an existing Python socket connection and return a TLS socket
object.
"""
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True)
loop = asyncio.get_running_loop()
if session:
ssl_conn.set_session(session)

View File

@ -16,9 +16,7 @@
from __future__ import annotations
import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
@ -39,19 +37,16 @@ from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
receive_data,
sendall,
)
from pymongo.socket_checker import _errno_from_exception
if TYPE_CHECKING:
from bson import CodecOptions
@ -317,7 +312,7 @@ def receive_message(
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
@ -332,12 +327,10 @@ def receive_message(
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(conn, 9, deadline)
)
data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
else:
data = _receive_data_on_socket(conn, length - 16, deadline)
data = receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
@ -346,63 +339,3 @@ def receive_message(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Wait until the connection's socket has data available to read.

    Returns immediately when the socket's ``fileno()`` is -1. Otherwise polls
    in a loop until the socket is readable.

    :param conn: connection whose ``conn`` attribute is the socket to watch.
    :param deadline: ``time.monotonic()`` based deadline, or a falsy value for
        no deadline.
    :raises _OperationCancelled: if ``conn.cancel_context.cancelled`` is set
        after a poll.
    :raises socket.timeout: if the deadline had already passed when the last
        poll started and that poll found nothing readable.
    """
    raw_sock = conn.conn
    # Check if the connection's socket has been manually closed
    if raw_sock.fileno() == -1:
        return
    deadline_passed = False
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        has_buffered = hasattr(raw_sock, "pending") and raw_sock.pending() > 0
        if has_buffered:
            ready = True
        else:
            # Wait up to _POLL_TIMEOUT for the socket to become readable and
            # then check for cancellation.
            wait_for = _POLL_TIMEOUT
            if deadline:
                time_left = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                deadline_passed = deadline_passed or time_left <= 0
                wait_for = max(min(time_left, _POLL_TIMEOUT), 0)
            ready = conn.socket_checker.select(raw_sock, read=True, timeout=wait_for)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if ready:
            return
        if deadline_passed:
            raise socket.timeout("timed out")
def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    """Read exactly ``length`` bytes from ``conn.conn`` into a new buffer.

    :param conn: connection to read from.
    :param length: number of bytes to read.
    :param deadline: ``time.monotonic()`` based deadline forwarded to
        ``wait_for_read``, or None.
    :returns: a memoryview over the bytes that were read.
    :raises socket.timeout: if the read raises one of ``BLOCKING_IO_ERRORS``.
    :raises OSError: "connection closed" if ``recv_into`` returns 0; other
        OS errors are re-raised unless their errno is EINTR.
    """
    view = memoryview(bytearray(length))
    received = 0
    while received < length:
        try:
            wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            count = conn.conn.recv_into(view[received:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as err:
            # An interrupted call is retried; everything else propagates.
            if _errno_from_exception(err) != errno.EINTR:
                raise
            continue
        if count == 0:
            raise OSError("connection closed")
        received += count
    return view

View File

@ -1713,6 +1713,7 @@ class TestClient(AsyncIntegrationTest):
# No error
await client.pymongo_test.test.find_one()
@async_client_context.require_sync
async def test_reset_during_update_pool(self):
client = await self.async_rs_or_single_client(minPoolSize=10)
await client.admin.command("ping")
@ -1737,10 +1738,7 @@ class TestClient(AsyncIntegrationTest):
await asyncio.sleep(0.001)
def run(self):
if _IS_SYNC:
self._run()
else:
asyncio.run(self._run())
self._run()
t = ResetPoolThread(pool)
t.start()

View File

@ -1671,6 +1671,7 @@ class TestClient(IntegrationTest):
# No error
client.pymongo_test.test.find_one()
@client_context.require_sync
def test_reset_during_update_pool(self):
client = self.rs_or_single_client(minPoolSize=10)
client.admin.command("ping")
@ -1695,10 +1696,7 @@ class TestClient(IntegrationTest):
time.sleep(0.001)
def run(self):
if _IS_SYNC:
self._run()
else:
asyncio.run(self._run())
self._run()
t = ResetPoolThread(pool)
t.start()

View File

@ -43,6 +43,7 @@ replacements = {
"AsyncConnection": "Connection",
"async_command": "command",
"async_receive_message": "receive_message",
"async_receive_data": "receive_data",
"async_sendall": "sendall",
"asynchronous": "synchronous",
"Asynchronous": "Synchronous",