diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 6087b1aa8..f1c378b9b 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -18,32 +18,111 @@ from __future__ import annotations import asyncio import socket import struct +import sys +from asyncio import AbstractEventLoop, Future from typing import ( - TYPE_CHECKING, Union, ) from pymongo import ssl_support -if TYPE_CHECKING: - from pymongo.pyopenssl_context import _sslConn +try: + from ssl import SSLError, SSLSocket + + _HAVE_SSL = True +except ImportError: + _HAVE_SSL = False + +try: + from pymongo.pyopenssl_context import ( + BLOCKING_IO_LOOKUP_ERROR, + BLOCKING_IO_READ_ERROR, + BLOCKING_IO_WRITE_ERROR, + _sslConn, + ) + + _HAVE_PYOPENSSL = True +except ImportError: + _HAVE_PYOPENSSL = False + _sslConn = SSLSocket # type: ignore + from pymongo.ssl_support import ( # type: ignore[assignment] + BLOCKING_IO_LOOKUP_ERROR, + BLOCKING_IO_READ_ERROR, + BLOCKING_IO_WRITE_ERROR, + ) _UNPACK_HEADER = struct.Struct(" None: - timeout = socket.gettimeout() - socket.settimeout(0.0) +async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: + timeout = sock.gettimeout() + sock.settimeout(0.0) loop = asyncio.get_event_loop() try: - await asyncio.wait_for(loop.sock_sendall(socket, buf), timeout=timeout) # type: ignore[arg-type] + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + if sys.platform == "win32": + await asyncio.wait_for(_async_sendall_ssl_windows(sock, buf), timeout=timeout) + else: + await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout) + else: + await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type] finally: - socket.settimeout(timeout) + sock.settimeout(timeout) -def sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None: - socket.sendall(buf) +async def _async_sendall_ssl( + sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop +) -> None: + fd = sock.fileno() + sent = 0 + + def _is_ready(fut: Future) -> None: + loop.remove_writer(fd) + loop.remove_reader(fd) + if fut.done(): + return + fut.set_result(None) + + while sent < len(buf): + try: + sent += sock.send(buf) + except BLOCKING_IO_ERRORS as exc: + fd = sock.fileno() + # Check for closed socket. + if fd == -1: + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, BLOCKING_IO_READ_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + await fut + if isinstance(exc, BLOCKING_IO_WRITE_ERROR): + fut = loop.create_future() + loop.add_writer(fd, _is_ready, fut) + await fut + if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + loop.add_writer(fd, _is_ready, fut) + await fut + + +# The default Windows asyncio event loop does not support loop.add_reader/add_writer: https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support +async def _async_sendall_ssl_windows(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: + view = memoryview(buf) + total_length = len(buf) + total_sent = 0 + while total_sent < total_length: + try: + sent = sock.send(view[total_sent:]) + except BLOCKING_IO_ERRORS: + await asyncio.sleep(0.5) + sent = 0 + total_sent += sent + + +def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: + sock.sendall(buf) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index c1b85af12..4f6f6f4a8 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -90,6 +90,9 @@ def _is_ip_address(address: Any) -> bool: # According to the docs for socket.send it can raise # WantX509LookupError and should be retried. BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) +BLOCKING_IO_READ_ERROR = _SSL.WantReadError +BLOCKING_IO_WRITE_ERROR = _SSL.WantWriteError +BLOCKING_IO_LOOKUP_ERROR = _SSL.WantX509LookupError def _ragged_eof(exc: BaseException) -> bool: diff --git a/pymongo/ssl_context.py b/pymongo/ssl_context.py index 1a0424208..ee32145c0 100644 --- a/pymongo/ssl_context.py +++ b/pymongo/ssl_context.py @@ -30,6 +30,8 @@ IS_PYOPENSSL = False # Errors raised by SSL sockets when in non-blocking mode. BLOCKING_IO_ERRORS = (_ssl.SSLWantReadError, _ssl.SSLWantWriteError) +BLOCKING_IO_READ_ERROR = _ssl.SSLWantReadError +BLOCKING_IO_WRITE_ERROR = _ssl.SSLWantWriteError # Base Exception class SSLError = _ssl.SSLError diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 6a5dd278d..580d71f9b 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -53,6 +53,9 @@ if HAVE_SSL: IPADDR_SAFE = True SSLError = _ssl.SSLError BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS + BLOCKING_IO_READ_ERROR = _ssl.BLOCKING_IO_READ_ERROR + BLOCKING_IO_WRITE_ERROR = _ssl.BLOCKING_IO_WRITE_ERROR + BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR def get_ssl_context( certfile: Optional[str],