Revert "Use Python 3.8 asyncio.Stream where possible (#369)" (#423)

This reverts commit 71cbde8ba4.
This commit is contained in:
Josep Cugat 2019-10-03 10:18:10 +02:00 committed by Tom Christie
parent b65bce5924
commit 86a0eb0268
4 changed files with 50 additions and 163 deletions

View File

@ -4,10 +4,9 @@ import ssl
import typing
from types import TracebackType
from httpx.config import PoolLimits, TimeoutConfig
from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from ..base import (
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
@ -16,7 +15,6 @@ from ..base import (
ConcurrencyBackend,
TimeoutFlag,
)
from .compat import Stream, connect_compat
SSL_MONKEY_PATCH_APPLIED = False
@ -43,12 +41,18 @@ def ssl_monkey_patch() -> None:
class TCPStream(BaseTCPStream):
def __init__(self, stream: Stream, timeout: TimeoutConfig):
self.stream = stream
def __init__(
self,
stream_reader: asyncio.StreamReader,
stream_writer: asyncio.StreamWriter,
timeout: TimeoutConfig,
):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
self.timeout = timeout
def get_http_version(self) -> str:
ssl_object = self.stream.get_extra_info("ssl_object")
ssl_object = self.stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
return "HTTP/1.1"
@ -68,7 +72,7 @@ class TCPStream(BaseTCPStream):
should_raise = flag is None or flag.raise_on_read_timeout
read_timeout = timeout.read_timeout if should_raise else 0.01
try:
data = await asyncio.wait_for(self.stream.read(n), read_timeout)
data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
break
except asyncio.TimeoutError:
if should_raise:
@ -83,7 +87,7 @@ class TCPStream(BaseTCPStream):
return data
def write_no_block(self, data: bytes) -> None:
self.stream.write(data) # pragma: nocover
self.stream_writer.write(data) # pragma: nocover
async def write(
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
@ -94,11 +98,11 @@ class TCPStream(BaseTCPStream):
if timeout is None:
timeout = self.timeout
self.stream.write(data)
self.stream_writer.write(data)
while True:
try:
await asyncio.wait_for( # type: ignore
self.stream.drain(), timeout.write_timeout
self.stream_writer.drain(), timeout.write_timeout
)
break
except asyncio.TimeoutError:
@ -124,12 +128,10 @@ class TCPStream(BaseTCPStream):
# (For a solution that uses private asyncio APIs, see:
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)
return self.stream.at_eof()
return self.stream_reader.at_eof()
async def close(self) -> None:
# FIXME: We should await on this call, but need a workaround for this first:
# https://github.com/aio-libs/aiohttp/issues/3535
self.stream.close()
self.stream_writer.close()
class PoolSemaphore(BasePoolSemaphore):
@ -188,13 +190,16 @@ class AsyncioBackend(ConcurrencyBackend):
timeout: TimeoutConfig,
) -> BaseTCPStream:
try:
stream = await asyncio.wait_for( # type: ignore
connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
return TCPStream(stream=stream, timeout=timeout)
return TCPStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
async def start_tls(
self,
@ -203,13 +208,35 @@ class AsyncioBackend(ConcurrencyBackend):
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)
assert isinstance(stream, TCPStream)
await asyncio.wait_for(
stream.stream.start_tls(ssl_context, server_hostname=hostname),
stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = stream.stream_writer.transport
loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)
stream_reader.set_transport(transport)
stream.stream_reader = stream_reader
stream.stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
return stream
async def run_in_threadpool(

View File

@ -1,3 +0,0 @@
from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream
__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]

View File

@ -1,137 +0,0 @@
import asyncio
import ssl
import sys
import typing
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
class Stream(Protocol): # pragma: no cover
"""Protocol defining just the methods we use from asyncio.Stream."""
def at_eof(self) -> bool:
...
def close(self) -> typing.Awaitable[None]:
...
async def drain(self) -> None:
...
def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
...
async def read(self, n: int = -1) -> bytes:
...
async def start_tls(
self,
sslContext: ssl.SSLContext,
*,
server_hostname: typing.Optional[str] = None,
ssl_handshake_timeout: typing.Optional[float] = None,
) -> None:
...
def write(self, data: bytes) -> typing.Awaitable[None]:
...
async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
if sys.version_info >= (3, 8):
return await asyncio.connect(*args, **kwargs)
else:
reader, writer = await asyncio.open_connection(*args, **kwargs)
return StreamCompat(reader, writer)
class StreamCompat:
"""
Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
behave similarly to an asyncio.Stream.
"""
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self.reader = reader
self.writer = writer
def at_eof(self) -> bool:
return self.reader.at_eof()
def close(self) -> typing.Awaitable[None]:
self.writer.close()
return _OptionalAwait(self.wait_closed)
async def drain(self) -> None:
await self.writer.drain()
def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
return self.writer.get_extra_info(name, default)
async def read(self, n: int = -1) -> bytes:
return await self.reader.read(n)
async def start_tls(
self,
sslContext: ssl.SSLContext,
*,
server_hostname: typing.Optional[str] = None,
ssl_handshake_timeout: typing.Optional[float] = None,
) -> None:
if not sys.version_info >= (3, 7): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)
else:
# This code is in an else branch to appease mypy on Python < 3.7
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport = self.writer.transport
loop = asyncio.get_event_loop()
loop_start_tls = loop.start_tls # type: ignore
tls_transport = await loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=sslContext,
server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
)
reader.set_transport(tls_transport)
self.reader = reader
self.writer = asyncio.StreamWriter(
transport=tls_transport, protocol=protocol, reader=reader, loop=loop
)
def write(self, data: bytes) -> typing.Awaitable[None]:
self.writer.write(data)
return _OptionalAwait(self.drain)
async def wait_closed(self) -> None:
if sys.version_info >= (3, 7):
await self.writer.wait_closed()
# else not much we can do to wait for the connection to close
# This code is copied from cPython 3.8 but with type annotations added:
# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
_T = typing.TypeVar("_T")
class _OptionalAwait(typing.Generic[_T]):
# The class doesn't create a coroutine
# if not awaited
# It prevents "coroutine is never awaited" message
__slots___ = ("_method",)
def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
self._method = method
def __await__(self) -> typing.Generator[typing.Any, None, _T]:
return self._method().__await__()

View File

@ -25,11 +25,11 @@ async def test_start_tls_on_socket_stream(https_server):
try:
assert stream.is_connection_dropped() is False
assert stream.stream.get_extra_info("cipher", default=None) is None
assert stream.stream_writer.get_extra_info("cipher", default=None) is None
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert stream.stream.get_extra_info("cipher", default=None) is not None
assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")