PYTHON-4618 - Fix TypeError: Socket cannot be of type SSLSocket (#1772)
This commit is contained in:
parent
13cf110f01
commit
682f15b21e
@ -18,32 +18,111 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
from asyncio import AbstractEventLoop, Future
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import ssl_support
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
try:
|
||||
from ssl import SSLError, SSLSocket
|
||||
|
||||
_HAVE_SSL = True
|
||||
except ImportError:
|
||||
_HAVE_SSL = False
|
||||
|
||||
try:
|
||||
from pymongo.pyopenssl_context import (
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
_sslConn,
|
||||
)
|
||||
|
||||
_HAVE_PYOPENSSL = True
|
||||
except ImportError:
|
||||
_HAVE_PYOPENSSL = False
|
||||
_sslConn = SSLSocket # type: ignore
|
||||
from pymongo.ssl_support import ( # type: ignore[assignment]
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
_POLL_TIMEOUT = 0.5
|
||||
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
async def async_sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
timeout = socket.gettimeout()
|
||||
socket.settimeout(0.0)
|
||||
async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
timeout = sock.gettimeout()
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
await asyncio.wait_for(loop.sock_sendall(socket, buf), timeout=timeout) # type: ignore[arg-type]
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
if sys.platform == "win32":
|
||||
await asyncio.wait_for(_async_sendall_ssl_windows(sock, buf), timeout=timeout)
|
||||
else:
|
||||
await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout)
|
||||
else:
|
||||
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
|
||||
finally:
|
||||
socket.settimeout(timeout)
|
||||
sock.settimeout(timeout)
|
||||
|
||||
|
||||
def sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
socket.sendall(buf)
|
||||
async def _async_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
|
||||
) -> None:
|
||||
fd = sock.fileno()
|
||||
sent = 0
|
||||
|
||||
def _is_ready(fut: Future) -> None:
|
||||
loop.remove_writer(fd)
|
||||
loop.remove_reader(fd)
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
|
||||
while sent < len(buf):
|
||||
try:
|
||||
sent += sock.send(buf)
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = sock.fileno()
|
||||
# Check for closed socket.
|
||||
if fd == -1:
|
||||
raise SSLError("Underlying socket has been closed") from None
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
await fut
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
|
||||
|
||||
# The default Windows asyncio event loop does not support loop.add_reader/add_writer: https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
|
||||
async def _async_sendall_ssl_windows(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
view = memoryview(buf)
|
||||
total_length = len(buf)
|
||||
total_sent = 0
|
||||
while total_sent < total_length:
|
||||
try:
|
||||
sent = sock.send(view[total_sent:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(0.5)
|
||||
sent = 0
|
||||
total_sent += sent
|
||||
|
||||
|
||||
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
sock.sendall(buf)
|
||||
|
||||
@ -90,6 +90,9 @@ def _is_ip_address(address: Any) -> bool:
|
||||
# According to the docs for socket.send it can raise
|
||||
# WantX509LookupError and should be retried.
|
||||
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
|
||||
BLOCKING_IO_READ_ERROR = _SSL.WantReadError
|
||||
BLOCKING_IO_WRITE_ERROR = _SSL.WantWriteError
|
||||
BLOCKING_IO_LOOKUP_ERROR = _SSL.WantX509LookupError
|
||||
|
||||
|
||||
def _ragged_eof(exc: BaseException) -> bool:
|
||||
|
||||
@ -30,6 +30,8 @@ IS_PYOPENSSL = False
|
||||
|
||||
# Errors raised by SSL sockets when in non-blocking mode.
|
||||
BLOCKING_IO_ERRORS = (_ssl.SSLWantReadError, _ssl.SSLWantWriteError)
|
||||
BLOCKING_IO_READ_ERROR = _ssl.SSLWantReadError
|
||||
BLOCKING_IO_WRITE_ERROR = _ssl.SSLWantWriteError
|
||||
|
||||
# Base Exception class
|
||||
SSLError = _ssl.SSLError
|
||||
|
||||
@ -53,6 +53,9 @@ if HAVE_SSL:
|
||||
IPADDR_SAFE = True
|
||||
SSLError = _ssl.SSLError
|
||||
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
|
||||
BLOCKING_IO_READ_ERROR = _ssl.BLOCKING_IO_READ_ERROR
|
||||
BLOCKING_IO_WRITE_ERROR = _ssl.BLOCKING_IO_WRITE_ERROR
|
||||
BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR
|
||||
|
||||
def get_ssl_context(
|
||||
certfile: Optional[str],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user