PYTHON-4636 - Avoid blocking I/O calls in async code paths (#1870)
Co-authored-by: Shane Harvey <shnhrv@gmail.com>
This commit is contained in:
parent
7380097dbc
commit
b111cbf5d5
@ -15,11 +15,8 @@
|
||||
"""Internal network layer helper methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -40,19 +37,16 @@ from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
_OperationCancelled,
|
||||
)
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.network_layer import (
|
||||
_POLL_TIMEOUT,
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
_UNPACK_HEADER,
|
||||
BLOCKING_IO_ERRORS,
|
||||
async_receive_data,
|
||||
async_sendall,
|
||||
)
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
@ -318,9 +312,7 @@ async def receive_message(
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(
|
||||
await _receive_data_on_socket(conn, 16, deadline)
|
||||
)
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
@ -336,11 +328,11 @@ async def receive_message(
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
|
||||
await _receive_data_on_socket(conn, 9, deadline)
|
||||
await async_receive_data(conn, 9, deadline)
|
||||
)
|
||||
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
|
||||
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = await _receive_data_on_socket(conn, length - 16, deadline)
|
||||
data = await async_receive_data(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
@ -349,66 +341,3 @@ async def receive_message(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
|
||||
|
||||
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn
|
||||
timed_out = False
|
||||
# Check if the connection's socket has been manually closed
|
||||
if sock.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
if hasattr(sock, "pending") and sock.pending() > 0:
|
||||
readable = True
|
||||
else:
|
||||
# Wait up to 500ms for the socket to become readable and then
|
||||
# check for cancellation.
|
||||
if deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
if remaining <= 0:
|
||||
timed_out = True
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
if readable:
|
||||
return
|
||||
if timed_out:
|
||||
raise socket.timeout("timed out")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def _receive_data_on_socket(
|
||||
conn: AsyncConnection, length: int, deadline: Optional[float]
|
||||
) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
try:
|
||||
await wait_for_read(conn, deadline)
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
except OSError as exc:
|
||||
if _errno_from_exception(exc) == errno.EINTR:
|
||||
continue
|
||||
raise
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
|
||||
bytes_read += chunk_length
|
||||
|
||||
return mv
|
||||
|
||||
@ -16,15 +16,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from asyncio import AbstractEventLoop, Future
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import ssl_support
|
||||
from pymongo import _csot, ssl_support
|
||||
from pymongo.errors import _OperationCancelled
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
try:
|
||||
from ssl import SSLError, SSLSocket
|
||||
@ -51,6 +57,10 @@ except ImportError:
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
_POLL_TIMEOUT = 0.5
|
||||
@ -80,12 +90,9 @@ if sys.platform != "win32":
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
fd = sock.fileno()
|
||||
sent = 0
|
||||
|
||||
def _is_ready(fut: Future) -> None:
|
||||
loop.remove_writer(fd)
|
||||
loop.remove_reader(fd)
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
@ -101,33 +108,240 @@ if sys.platform != "win32":
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
await fut
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_writer(fd)
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
|
||||
async def _async_receive_ssl(
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
|
||||
def _is_ready(fut: Future) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
|
||||
while total_read < length:
|
||||
try:
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
total_read += read
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = conn.fileno()
|
||||
# Check for closed socket.
|
||||
if fd == -1:
|
||||
raise SSLError("Underlying socket has been closed") from None
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_writer(fd)
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
return mv
|
||||
|
||||
else:
|
||||
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
|
||||
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
|
||||
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
|
||||
async def _async_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
total_length = len(buf)
|
||||
total_sent = 0
|
||||
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
|
||||
# down to 1ms.
|
||||
backoff = 0.001
|
||||
while total_sent < total_length:
|
||||
try:
|
||||
sent = sock.send(view[total_sent:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(0.5)
|
||||
await asyncio.sleep(backoff)
|
||||
sent = 0
|
||||
if sent > 0:
|
||||
backoff = max(backoff / 2, 0.001)
|
||||
else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_sent += sent
|
||||
|
||||
async def _async_receive_ssl(
|
||||
conn: _sslConn, length: int, dummy: AbstractEventLoop
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
|
||||
# down to 1ms.
|
||||
backoff = 0.001
|
||||
while total_read < length:
|
||||
try:
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(backoff)
|
||||
read = 0
|
||||
if read > 0:
|
||||
backoff = max(backoff / 2, 0.001)
|
||||
else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_read += read
|
||||
return mv
|
||||
|
||||
|
||||
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
sock.sendall(buf)
|
||||
|
||||
|
||||
async def _poll_cancellation(conn: AsyncConnection) -> None:
|
||||
while True:
|
||||
if conn.cancel_context.cancelled:
|
||||
return
|
||||
|
||||
await asyncio.sleep(_POLL_TIMEOUT)
|
||||
|
||||
|
||||
async def async_receive_data(
|
||||
conn: AsyncConnection, length: int, deadline: Optional[float]
|
||||
) -> memoryview:
|
||||
sock = conn.conn
|
||||
sock_timeout = sock.gettimeout()
|
||||
timeout: Optional[Union[float, int]]
|
||||
if deadline:
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
timeout = max(deadline - time.monotonic(), 0)
|
||||
else:
|
||||
timeout = sock_timeout
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_event_loop()
|
||||
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||
else:
|
||||
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||
tasks = [read_task, cancellation_task]
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
await asyncio.wait(pending)
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
if read_task in done:
|
||||
return read_task.result()
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
bytes_read += chunk_length
|
||||
return mv
|
||||
|
||||
|
||||
# Sync version:
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn
|
||||
timed_out = False
|
||||
# Check if the connection's socket has been manually closed
|
||||
if sock.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
if hasattr(sock, "pending") and sock.pending() > 0:
|
||||
readable = True
|
||||
else:
|
||||
# Wait up to 500ms for the socket to become readable and then
|
||||
# check for cancellation.
|
||||
if deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
if remaining <= 0:
|
||||
timed_out = True
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
if readable:
|
||||
return
|
||||
if timed_out:
|
||||
raise socket.timeout("timed out")
|
||||
|
||||
|
||||
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
try:
|
||||
wait_for_read(conn, deadline)
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
except OSError as exc:
|
||||
if _errno_from_exception(exc) == errno.EINTR:
|
||||
continue
|
||||
raise
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
|
||||
bytes_read += chunk_length
|
||||
|
||||
return mv
|
||||
|
||||
@ -105,13 +105,19 @@ def _ragged_eof(exc: BaseException) -> bool:
|
||||
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
|
||||
class _sslConn(_SSL.Connection):
|
||||
def __init__(
|
||||
self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
|
||||
self,
|
||||
ctx: _SSL.Context,
|
||||
sock: Optional[_socket.socket],
|
||||
suppress_ragged_eofs: bool,
|
||||
is_async: bool = False,
|
||||
):
|
||||
self.socket_checker = _SocketChecker()
|
||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||
super().__init__(ctx, sock)
|
||||
self._is_async = is_async
|
||||
|
||||
def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
|
||||
is_async = kwargs.pop("allow_async", True) and self._is_async
|
||||
timeout = self.gettimeout()
|
||||
if timeout:
|
||||
start = _time.monotonic()
|
||||
@ -119,6 +125,8 @@ class _sslConn(_SSL.Connection):
|
||||
try:
|
||||
return call(*args, **kwargs)
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
if is_async:
|
||||
raise exc
|
||||
# Check for closed socket.
|
||||
if self.fileno() == -1:
|
||||
if timeout and _time.monotonic() - start > timeout:
|
||||
@ -139,6 +147,7 @@ class _sslConn(_SSL.Connection):
|
||||
continue
|
||||
|
||||
def do_handshake(self, *args: Any, **kwargs: Any) -> None:
|
||||
kwargs["allow_async"] = False
|
||||
return self._call(super().do_handshake, *args, **kwargs)
|
||||
|
||||
def recv(self, *args: Any, **kwargs: Any) -> bytes:
|
||||
@ -381,7 +390,7 @@ class SSLContext:
|
||||
"""Wrap an existing Python socket connection and return a TLS socket
|
||||
object.
|
||||
"""
|
||||
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
|
||||
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True)
|
||||
loop = asyncio.get_running_loop()
|
||||
if session:
|
||||
ssl_conn.set_session(session)
|
||||
|
||||
@ -16,9 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -39,19 +37,16 @@ from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
_OperationCancelled,
|
||||
)
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.network_layer import (
|
||||
_POLL_TIMEOUT,
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
_UNPACK_HEADER,
|
||||
BLOCKING_IO_ERRORS,
|
||||
receive_data,
|
||||
sendall,
|
||||
)
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
@ -317,7 +312,7 @@ def receive_message(
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
@ -332,12 +327,10 @@ def receive_message(
|
||||
f"message size ({max_message_size!r})"
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
|
||||
_receive_data_on_socket(conn, 9, deadline)
|
||||
)
|
||||
data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
|
||||
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = _receive_data_on_socket(conn, length - 16, deadline)
|
||||
data = receive_data(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
@ -346,63 +339,3 @@ def receive_message(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
|
||||
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn
|
||||
timed_out = False
|
||||
# Check if the connection's socket has been manually closed
|
||||
if sock.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
if hasattr(sock, "pending") and sock.pending() > 0:
|
||||
readable = True
|
||||
else:
|
||||
# Wait up to 500ms for the socket to become readable and then
|
||||
# check for cancellation.
|
||||
if deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
if remaining <= 0:
|
||||
timed_out = True
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
if readable:
|
||||
return
|
||||
if timed_out:
|
||||
raise socket.timeout("timed out")
|
||||
|
||||
|
||||
def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
try:
|
||||
wait_for_read(conn, deadline)
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
except OSError as exc:
|
||||
if _errno_from_exception(exc) == errno.EINTR:
|
||||
continue
|
||||
raise
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
|
||||
bytes_read += chunk_length
|
||||
|
||||
return mv
|
||||
|
||||
@ -1713,6 +1713,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# No error
|
||||
await client.pymongo_test.test.find_one()
|
||||
|
||||
@async_client_context.require_sync
|
||||
async def test_reset_during_update_pool(self):
|
||||
client = await self.async_rs_or_single_client(minPoolSize=10)
|
||||
await client.admin.command("ping")
|
||||
@ -1737,10 +1738,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
def run(self):
|
||||
if _IS_SYNC:
|
||||
self._run()
|
||||
else:
|
||||
asyncio.run(self._run())
|
||||
self._run()
|
||||
|
||||
t = ResetPoolThread(pool)
|
||||
t.start()
|
||||
|
||||
@ -1671,6 +1671,7 @@ class TestClient(IntegrationTest):
|
||||
# No error
|
||||
client.pymongo_test.test.find_one()
|
||||
|
||||
@client_context.require_sync
|
||||
def test_reset_during_update_pool(self):
|
||||
client = self.rs_or_single_client(minPoolSize=10)
|
||||
client.admin.command("ping")
|
||||
@ -1695,10 +1696,7 @@ class TestClient(IntegrationTest):
|
||||
time.sleep(0.001)
|
||||
|
||||
def run(self):
|
||||
if _IS_SYNC:
|
||||
self._run()
|
||||
else:
|
||||
asyncio.run(self._run())
|
||||
self._run()
|
||||
|
||||
t = ResetPoolThread(pool)
|
||||
t.start()
|
||||
|
||||
@ -43,6 +43,7 @@ replacements = {
|
||||
"AsyncConnection": "Connection",
|
||||
"async_command": "command",
|
||||
"async_receive_message": "receive_message",
|
||||
"async_receive_data": "receive_data",
|
||||
"async_sendall": "sendall",
|
||||
"asynchronous": "synchronous",
|
||||
"Asynchronous": "Synchronous",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user