Merge pull request #255 from encode/refactor/stream-interface

Unify BaseReader and BaseWriter as BaseStream
This commit is contained in:
Tom Christie 2019-08-21 09:38:39 +01:00 committed by GitHub
commit 30692272fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 75 additions and 90 deletions

View File

@ -5,8 +5,7 @@ from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseReader,
BaseWriter,
BaseStream,
ConcurrencyBackend,
)
from .config import (
@ -105,8 +104,7 @@ __all__ = [
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
"BaseReader",
"BaseWriter",
"BaseStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",

View File

@ -1,5 +1,5 @@
"""
The `Reader` and `Writer` classes here provide a lightweight layer over
The `Stream` class here provides a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
@ -14,18 +14,17 @@ import ssl
import typing
from types import TracebackType
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseEvent,
BaseQueue,
BaseReader,
BaseWriter,
BaseStream,
ConcurrencyBackend,
TimeoutFlag,
)
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
SSL_MONKEY_PATCH_APPLIED = False
@ -51,13 +50,29 @@ def ssl_monkey_patch() -> None:
MonkeyPatch.write = _fixed_write
class Reader(BaseReader):
class Stream(BaseStream):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
) -> None:
self,
stream_reader: asyncio.StreamReader,
stream_writer: asyncio.StreamWriter,
timeout: TimeoutConfig,
):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
self.timeout = timeout
def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
return "HTTP/1.1"
ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
async def read(
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> bytes:
@ -78,15 +93,6 @@ class Reader(BaseReader):
return data
def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()
class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
self.stream_writer = stream_writer
self.timeout = timeout
def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover
@ -114,6 +120,9 @@ class Writer(BaseWriter):
if should_raise:
raise WriteTimeout() from None
def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()
async def close(self) -> None:
self.stream_writer.close()
@ -172,7 +181,7 @@ class AsyncioBackend(ConcurrencyBackend):
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> BaseStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
@ -181,19 +190,9 @@ class AsyncioBackend(ConcurrencyBackend):
except asyncio.TimeoutError:
raise ConnectTimeout()
ssl_object = stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
ident = "http/1.1"
else:
ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()
reader = Reader(stream_reader=stream_reader, timeout=timeout)
writer = Writer(stream_writer=stream_writer, timeout=timeout)
http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"
return reader, writer, http_version
return Stream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any

View File

@ -37,29 +37,21 @@ class TimeoutFlag:
self.raise_on_write_timeout = True
class BaseReader:
class BaseStream:
"""
A stream reader. Abstracts away any asyncio-specific interfaces
into a more generic base class, that we can use with alternate
backend, or for stand-alone test cases.
A stream with read/write operations. Abstracts away any asyncio-specific
interfaces into a more generic base class, that we can use with alternate
backends, or for stand-alone test cases.
"""
def get_http_version(self) -> str:
raise NotImplementedError() # pragma: no cover
async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
raise NotImplementedError() # pragma: no cover
def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover
class BaseWriter:
"""
A stream writer. Abstracts away any asyncio-specific interfaces
into a more generic base class, that we can use with alternate
backend, or for stand-alone test cases.
"""
def write_no_block(self, data: bytes) -> None:
raise NotImplementedError() # pragma: no cover
@ -69,6 +61,9 @@ class BaseWriter:
async def close(self) -> None:
raise NotImplementedError() # pragma: no cover
def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover
class BaseQueue:
"""
@ -118,7 +113,7 @@ class ConcurrencyBackend:
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> BaseStream:
raise NotImplementedError() # pragma: no cover
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:

View File

@ -79,24 +79,24 @@ class HTTPConnection(AsyncDispatcher):
else:
on_release = functools.partial(self.release_func, self)
reader, writer, http_version = await self.backend.connect(
host, port, ssl_context, timeout
)
stream = await self.backend.connect(host, port, ssl_context, timeout)
http_version = stream.get_http_version()
if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
reader, writer, self.backend, on_release=on_release
stream, self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
self.h11_connection = HTTP11Connection(
reader, writer, self.backend, on_release=on_release
stream, self.backend, on_release=on_release
)
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
return None
# Run the SSL loading in a threadpool, since it may makes disk accesses.
# Run the SSL loading in a threadpool, since it may make disk accesses.
return await self.backend.run_in_threadpool(
ssl.load_ssl_context, self.http_versions
)

View File

@ -2,7 +2,7 @@ import typing
import h11
from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
@ -27,13 +27,11 @@ class HTTP11Connection:
def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.reader = reader
self.writer = writer
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
@ -67,7 +65,7 @@ class HTTP11Connection:
except h11.LocalProtocolError: # pragma: no cover
# Premature client disconnect
pass
await self.writer.close()
await self.stream.close()
async def _send_request(
self, request: AsyncRequest, timeout: TimeoutConfig = None
@ -111,7 +109,7 @@ class HTTP11Connection:
drain before returning.
"""
bytes_to_send = self.h11_state.send(event)
await self.writer.write(bytes_to_send, timeout)
await self.stream.write(bytes_to_send, timeout)
async def _receive_response(
self, timeout: TimeoutConfig = None
@ -154,7 +152,7 @@ class HTTP11Connection:
event = self.h11_state.next_event()
if event is h11.NEED_DATA:
try:
data = await self.reader.read(
data = await self.stream.read(
self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
)
except OSError: # pragma: nocover
@ -184,4 +182,4 @@ class HTTP11Connection:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
return self.stream.is_connection_dropped()

View File

@ -4,7 +4,7 @@ import typing
import h2.connection
import h2.events
from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
@ -14,13 +14,11 @@ class HTTP2Connection:
def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
self.reader = reader
self.writer = writer
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
@ -58,12 +56,12 @@ class HTTP2Connection:
)
async def close(self) -> None:
await self.writer.close()
await self.stream.close()
def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.writer.write_no_block(data_to_send)
self.stream.write_no_block(data_to_send)
self.initialized = True
async def send_headers(
@ -78,7 +76,7 @@ class HTTP2Connection:
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)
return stream_id
async def send_request_data(
@ -104,12 +102,12 @@ class HTTP2Connection:
chunk = data[idx : idx + chunk_size]
self.h2_state.send_data(stream_id, chunk)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)
async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)
async def receive_response(
self, stream_id: int, timeout: TimeoutConfig = None
@ -150,14 +148,14 @@ class HTTP2Connection:
) -> h2.events.Event:
while not self.events[stream_id]:
flag = self.timeout_flags[stream_id]
data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
self.events[event.stream_id].append(event)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)
return self.events[stream_id].pop(0)
@ -173,4 +171,4 @@ class HTTP2Connection:
return False
def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
return self.stream.is_connection_dropped()

View File

@ -6,7 +6,7 @@ import h2.config
import h2.connection
import h2.events
from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig
from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig
class MockHTTP2Backend(AsyncioBackend):
@ -20,16 +20,12 @@ class MockHTTP2Backend(AsyncioBackend):
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> BaseStream:
self.server = MockHTTP2Server(self.app)
return self.server, self.server, "HTTP/2"
return self.server
class MockHTTP2Server(BaseReader, BaseWriter):
"""
This class exposes Reader and Writer style interfaces.
"""
class MockHTTP2Server(BaseStream):
def __init__(self, app):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
@ -38,15 +34,16 @@ class MockHTTP2Server(BaseReader, BaseWriter):
self.requests = {}
self.close_connection = False
# BaseReader interface
# Stream interface
def get_http_version(self) -> str:
return "HTTP/2"
async def read(self, n, timeout, flag=None) -> bytes:
await asyncio.sleep(0)
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send
# BaseWriter interface
def write_no_block(self, data: bytes) -> None:
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()