This reverts commit 71cbde8ba4.
This commit is contained in:
parent
b65bce5924
commit
86a0eb0268
@ -4,10 +4,9 @@ import ssl
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from httpx.config import PoolLimits, TimeoutConfig
|
||||
from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
|
||||
from ..base import (
|
||||
from ..config import PoolLimits, TimeoutConfig
|
||||
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
from .base import (
|
||||
BaseBackgroundManager,
|
||||
BaseEvent,
|
||||
BasePoolSemaphore,
|
||||
@ -16,7 +15,6 @@ from ..base import (
|
||||
ConcurrencyBackend,
|
||||
TimeoutFlag,
|
||||
)
|
||||
from .compat import Stream, connect_compat
|
||||
|
||||
SSL_MONKEY_PATCH_APPLIED = False
|
||||
|
||||
@ -43,12 +41,18 @@ def ssl_monkey_patch() -> None:
|
||||
|
||||
|
||||
class TCPStream(BaseTCPStream):
|
||||
def __init__(self, stream: Stream, timeout: TimeoutConfig):
|
||||
self.stream = stream
|
||||
def __init__(
|
||||
self,
|
||||
stream_reader: asyncio.StreamReader,
|
||||
stream_writer: asyncio.StreamWriter,
|
||||
timeout: TimeoutConfig,
|
||||
):
|
||||
self.stream_reader = stream_reader
|
||||
self.stream_writer = stream_writer
|
||||
self.timeout = timeout
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
ssl_object = self.stream.get_extra_info("ssl_object")
|
||||
ssl_object = self.stream_writer.get_extra_info("ssl_object")
|
||||
|
||||
if ssl_object is None:
|
||||
return "HTTP/1.1"
|
||||
@ -68,7 +72,7 @@ class TCPStream(BaseTCPStream):
|
||||
should_raise = flag is None or flag.raise_on_read_timeout
|
||||
read_timeout = timeout.read_timeout if should_raise else 0.01
|
||||
try:
|
||||
data = await asyncio.wait_for(self.stream.read(n), read_timeout)
|
||||
data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
if should_raise:
|
||||
@ -83,7 +87,7 @@ class TCPStream(BaseTCPStream):
|
||||
return data
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
self.stream.write(data) # pragma: nocover
|
||||
self.stream_writer.write(data) # pragma: nocover
|
||||
|
||||
async def write(
|
||||
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
|
||||
@ -94,11 +98,11 @@ class TCPStream(BaseTCPStream):
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
self.stream.write(data)
|
||||
self.stream_writer.write(data)
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for( # type: ignore
|
||||
self.stream.drain(), timeout.write_timeout
|
||||
self.stream_writer.drain(), timeout.write_timeout
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
@ -124,12 +128,10 @@ class TCPStream(BaseTCPStream):
|
||||
# (For a solution that uses private asyncio APIs, see:
|
||||
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)
|
||||
|
||||
return self.stream.at_eof()
|
||||
return self.stream_reader.at_eof()
|
||||
|
||||
async def close(self) -> None:
|
||||
# FIXME: We should await on this call, but need a workaround for this first:
|
||||
# https://github.com/aio-libs/aiohttp/issues/3535
|
||||
self.stream.close()
|
||||
self.stream_writer.close()
|
||||
|
||||
|
||||
class PoolSemaphore(BasePoolSemaphore):
|
||||
@ -188,13 +190,16 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseTCPStream:
|
||||
try:
|
||||
stream = await asyncio.wait_for( # type: ignore
|
||||
connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
|
||||
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
|
||||
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
||||
timeout.connect_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectTimeout()
|
||||
|
||||
return TCPStream(stream=stream, timeout=timeout)
|
||||
return TCPStream(
|
||||
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
|
||||
)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
@ -203,13 +208,35 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseTCPStream:
|
||||
|
||||
loop = self.loop
|
||||
if not hasattr(loop, "start_tls"): # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
|
||||
)
|
||||
|
||||
assert isinstance(stream, TCPStream)
|
||||
|
||||
await asyncio.wait_for(
|
||||
stream.stream.start_tls(ssl_context, server_hostname=hostname),
|
||||
stream_reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(stream_reader)
|
||||
transport = stream.stream_writer.transport
|
||||
|
||||
loop_start_tls = loop.start_tls # type: ignore
|
||||
transport = await asyncio.wait_for(
|
||||
loop_start_tls(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
sslcontext=ssl_context,
|
||||
server_hostname=hostname,
|
||||
),
|
||||
timeout=timeout.connect_timeout,
|
||||
)
|
||||
|
||||
stream_reader.set_transport(transport)
|
||||
stream.stream_reader = stream_reader
|
||||
stream.stream_writer = asyncio.StreamWriter(
|
||||
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
|
||||
)
|
||||
return stream
|
||||
|
||||
async def run_in_threadpool(
|
||||
@ -1,3 +0,0 @@
|
||||
from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream
|
||||
|
||||
__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]
|
||||
@ -1,137 +0,0 @@
|
||||
import asyncio
|
||||
import ssl
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Protocol
|
||||
else:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
class Stream(Protocol): # pragma: no cover
|
||||
"""Protocol defining just the methods we use from asyncio.Stream."""
|
||||
|
||||
def at_eof(self) -> bool:
|
||||
...
|
||||
|
||||
def close(self) -> typing.Awaitable[None]:
|
||||
...
|
||||
|
||||
async def drain(self) -> None:
|
||||
...
|
||||
|
||||
def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
|
||||
...
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
...
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
sslContext: ssl.SSLContext,
|
||||
*,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
ssl_handshake_timeout: typing.Optional[float] = None,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def write(self, data: bytes) -> typing.Awaitable[None]:
|
||||
...
|
||||
|
||||
|
||||
async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
|
||||
if sys.version_info >= (3, 8):
|
||||
return await asyncio.connect(*args, **kwargs)
|
||||
else:
|
||||
reader, writer = await asyncio.open_connection(*args, **kwargs)
|
||||
return StreamCompat(reader, writer)
|
||||
|
||||
|
||||
class StreamCompat:
|
||||
"""
|
||||
Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
|
||||
behave similarly to an asyncio.Stream.
|
||||
"""
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
def at_eof(self) -> bool:
|
||||
return self.reader.at_eof()
|
||||
|
||||
def close(self) -> typing.Awaitable[None]:
|
||||
self.writer.close()
|
||||
return _OptionalAwait(self.wait_closed)
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self.writer.drain()
|
||||
|
||||
def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
|
||||
return self.writer.get_extra_info(name, default)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
return await self.reader.read(n)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
sslContext: ssl.SSLContext,
|
||||
*,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
ssl_handshake_timeout: typing.Optional[float] = None,
|
||||
) -> None:
|
||||
if not sys.version_info >= (3, 7): # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
|
||||
)
|
||||
else:
|
||||
# This code is in an else branch to appease mypy on Python < 3.7
|
||||
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
transport = self.writer.transport
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop_start_tls = loop.start_tls # type: ignore
|
||||
tls_transport = await loop_start_tls(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
sslcontext=sslContext,
|
||||
server_hostname=server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
)
|
||||
|
||||
reader.set_transport(tls_transport)
|
||||
self.reader = reader
|
||||
self.writer = asyncio.StreamWriter(
|
||||
transport=tls_transport, protocol=protocol, reader=reader, loop=loop
|
||||
)
|
||||
|
||||
def write(self, data: bytes) -> typing.Awaitable[None]:
|
||||
self.writer.write(data)
|
||||
return _OptionalAwait(self.drain)
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
if sys.version_info >= (3, 7):
|
||||
await self.writer.wait_closed()
|
||||
# else not much we can do to wait for the connection to close
|
||||
|
||||
|
||||
# This code is copied from cPython 3.8 but with type annotations added:
|
||||
# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
|
||||
class _OptionalAwait(typing.Generic[_T]):
|
||||
# The class doesn't create a coroutine
|
||||
# if not awaited
|
||||
# It prevents "coroutine is never awaited" message
|
||||
|
||||
__slots___ = ("_method",)
|
||||
|
||||
def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
|
||||
self._method = method
|
||||
|
||||
def __await__(self) -> typing.Generator[typing.Any, None, _T]:
|
||||
return self._method().__await__()
|
||||
@ -25,11 +25,11 @@ async def test_start_tls_on_socket_stream(https_server):
|
||||
|
||||
try:
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert stream.stream.get_extra_info("cipher", default=None) is None
|
||||
assert stream.stream_writer.get_extra_info("cipher", default=None) is None
|
||||
|
||||
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert stream.stream.get_extra_info("cipher", default=None) is not None
|
||||
assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
|
||||
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user