Merge pull request #255 from encode/refactor/stream-interface
Unify BaseReader and BaseWriter as BaseStream
This commit is contained in:
commit
30692272fd
@ -5,8 +5,7 @@ from .concurrency.asyncio import AsyncioBackend
|
||||
from .concurrency.base import (
|
||||
BaseBackgroundManager,
|
||||
BasePoolSemaphore,
|
||||
BaseReader,
|
||||
BaseWriter,
|
||||
BaseStream,
|
||||
ConcurrencyBackend,
|
||||
)
|
||||
from .config import (
|
||||
@ -105,8 +104,7 @@ __all__ = [
|
||||
"TooManyRedirects",
|
||||
"WriteTimeout",
|
||||
"AsyncDispatcher",
|
||||
"BaseReader",
|
||||
"BaseWriter",
|
||||
"BaseStream",
|
||||
"ConcurrencyBackend",
|
||||
"Dispatcher",
|
||||
"URL",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
The `Reader` and `Writer` classes here provide a lightweight layer over
|
||||
The `Stream` class here provides a lightweight layer over
|
||||
`asyncio.StreamReader` and `asyncio.StreamWriter`.
|
||||
|
||||
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
|
||||
@ -14,18 +14,17 @@ import ssl
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from ..config import PoolLimits, TimeoutConfig
|
||||
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
from .base import (
|
||||
BaseBackgroundManager,
|
||||
BasePoolSemaphore,
|
||||
BaseEvent,
|
||||
BaseQueue,
|
||||
BaseReader,
|
||||
BaseWriter,
|
||||
BaseStream,
|
||||
ConcurrencyBackend,
|
||||
TimeoutFlag,
|
||||
)
|
||||
from ..config import PoolLimits, TimeoutConfig
|
||||
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
|
||||
SSL_MONKEY_PATCH_APPLIED = False
|
||||
|
||||
@ -51,13 +50,29 @@ def ssl_monkey_patch() -> None:
|
||||
MonkeyPatch.write = _fixed_write
|
||||
|
||||
|
||||
class Reader(BaseReader):
|
||||
class Stream(BaseStream):
|
||||
def __init__(
|
||||
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
|
||||
) -> None:
|
||||
self,
|
||||
stream_reader: asyncio.StreamReader,
|
||||
stream_writer: asyncio.StreamWriter,
|
||||
timeout: TimeoutConfig,
|
||||
):
|
||||
self.stream_reader = stream_reader
|
||||
self.stream_writer = stream_writer
|
||||
self.timeout = timeout
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
ssl_object = self.stream_writer.get_extra_info("ssl_object")
|
||||
|
||||
if ssl_object is None:
|
||||
return "HTTP/1.1"
|
||||
|
||||
ident = ssl_object.selected_alpn_protocol()
|
||||
if ident is None:
|
||||
ident = ssl_object.selected_npn_protocol()
|
||||
|
||||
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
|
||||
|
||||
async def read(
|
||||
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
|
||||
) -> bytes:
|
||||
@ -78,15 +93,6 @@ class Reader(BaseReader):
|
||||
|
||||
return data
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.stream_reader.at_eof()
|
||||
|
||||
|
||||
class Writer(BaseWriter):
|
||||
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
|
||||
self.stream_writer = stream_writer
|
||||
self.timeout = timeout
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
self.stream_writer.write(data) # pragma: nocover
|
||||
|
||||
@ -114,6 +120,9 @@ class Writer(BaseWriter):
|
||||
if should_raise:
|
||||
raise WriteTimeout() from None
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.stream_reader.at_eof()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.stream_writer.close()
|
||||
|
||||
@ -172,7 +181,7 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, str]:
|
||||
) -> BaseStream:
|
||||
try:
|
||||
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
|
||||
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
||||
@ -181,19 +190,9 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectTimeout()
|
||||
|
||||
ssl_object = stream_writer.get_extra_info("ssl_object")
|
||||
if ssl_object is None:
|
||||
ident = "http/1.1"
|
||||
else:
|
||||
ident = ssl_object.selected_alpn_protocol()
|
||||
if ident is None:
|
||||
ident = ssl_object.selected_npn_protocol()
|
||||
|
||||
reader = Reader(stream_reader=stream_reader, timeout=timeout)
|
||||
writer = Writer(stream_writer=stream_writer, timeout=timeout)
|
||||
http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"
|
||||
|
||||
return reader, writer, http_version
|
||||
return Stream(
|
||||
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
|
||||
)
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
|
||||
@ -37,29 +37,21 @@ class TimeoutFlag:
|
||||
self.raise_on_write_timeout = True
|
||||
|
||||
|
||||
class BaseReader:
|
||||
class BaseStream:
|
||||
"""
|
||||
A stream reader. Abstracts away any asyncio-specific interfaces
|
||||
into a more generic base class, that we can use with alternate
|
||||
backend, or for stand-alone test cases.
|
||||
A stream with read/write operations. Abstracts away any asyncio-specific
|
||||
interfaces into a more generic base class, that we can use with alternate
|
||||
backends, or for stand-alone test cases.
|
||||
"""
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def read(
|
||||
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
|
||||
) -> bytes:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class BaseWriter:
|
||||
"""
|
||||
A stream writer. Abstracts away any asyncio-specific interfaces
|
||||
into a more generic base class, that we can use with alternate
|
||||
backend, or for stand-alone test cases.
|
||||
"""
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
@ -69,6 +61,9 @@ class BaseWriter:
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class BaseQueue:
|
||||
"""
|
||||
@ -118,7 +113,7 @@ class ConcurrencyBackend:
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, str]:
|
||||
) -> BaseStream:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
|
||||
@ -79,24 +79,24 @@ class HTTPConnection(AsyncDispatcher):
|
||||
else:
|
||||
on_release = functools.partial(self.release_func, self)
|
||||
|
||||
reader, writer, http_version = await self.backend.connect(
|
||||
host, port, ssl_context, timeout
|
||||
)
|
||||
stream = await self.backend.connect(host, port, ssl_context, timeout)
|
||||
http_version = stream.get_http_version()
|
||||
|
||||
if http_version == "HTTP/2":
|
||||
self.h2_connection = HTTP2Connection(
|
||||
reader, writer, self.backend, on_release=on_release
|
||||
stream, self.backend, on_release=on_release
|
||||
)
|
||||
else:
|
||||
assert http_version == "HTTP/1.1"
|
||||
self.h11_connection = HTTP11Connection(
|
||||
reader, writer, self.backend, on_release=on_release
|
||||
stream, self.backend, on_release=on_release
|
||||
)
|
||||
|
||||
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
|
||||
if not self.origin.is_ssl:
|
||||
return None
|
||||
|
||||
# Run the SSL loading in a threadpool, since it may makes disk accesses.
|
||||
# Run the SSL loading in a threadpool, since it may make disk accesses.
|
||||
return await self.backend.run_in_threadpool(
|
||||
ssl.load_ssl_context, self.http_versions
|
||||
)
|
||||
|
||||
@ -2,7 +2,7 @@ import typing
|
||||
|
||||
import h11
|
||||
|
||||
from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
|
||||
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
|
||||
from ..config import TimeoutConfig, TimeoutTypes
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
|
||||
@ -27,13 +27,11 @@ class HTTP11Connection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: BaseReader,
|
||||
writer: BaseWriter,
|
||||
stream: BaseStream,
|
||||
backend: ConcurrencyBackend,
|
||||
on_release: typing.Optional[OnReleaseCallback] = None,
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.stream = stream
|
||||
self.backend = backend
|
||||
self.on_release = on_release
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
@ -67,7 +65,7 @@ class HTTP11Connection:
|
||||
except h11.LocalProtocolError: # pragma: no cover
|
||||
# Premature client disconnect
|
||||
pass
|
||||
await self.writer.close()
|
||||
await self.stream.close()
|
||||
|
||||
async def _send_request(
|
||||
self, request: AsyncRequest, timeout: TimeoutConfig = None
|
||||
@ -111,7 +109,7 @@ class HTTP11Connection:
|
||||
drain before returning.
|
||||
"""
|
||||
bytes_to_send = self.h11_state.send(event)
|
||||
await self.writer.write(bytes_to_send, timeout)
|
||||
await self.stream.write(bytes_to_send, timeout)
|
||||
|
||||
async def _receive_response(
|
||||
self, timeout: TimeoutConfig = None
|
||||
@ -154,7 +152,7 @@ class HTTP11Connection:
|
||||
event = self.h11_state.next_event()
|
||||
if event is h11.NEED_DATA:
|
||||
try:
|
||||
data = await self.reader.read(
|
||||
data = await self.stream.read(
|
||||
self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
|
||||
)
|
||||
except OSError: # pragma: nocover
|
||||
@ -184,4 +182,4 @@ class HTTP11Connection:
|
||||
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.reader.is_connection_dropped()
|
||||
return self.stream.is_connection_dropped()
|
||||
|
||||
@ -4,7 +4,7 @@ import typing
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
|
||||
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
|
||||
from ..config import TimeoutConfig, TimeoutTypes
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
|
||||
@ -14,13 +14,11 @@ class HTTP2Connection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: BaseReader,
|
||||
writer: BaseWriter,
|
||||
stream: BaseStream,
|
||||
backend: ConcurrencyBackend,
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.stream = stream
|
||||
self.backend = backend
|
||||
self.on_release = on_release
|
||||
self.h2_state = h2.connection.H2Connection()
|
||||
@ -58,12 +56,12 @@ class HTTP2Connection:
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.writer.close()
|
||||
await self.stream.close()
|
||||
|
||||
def initiate_connection(self) -> None:
|
||||
self.h2_state.initiate_connection()
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.writer.write_no_block(data_to_send)
|
||||
self.stream.write_no_block(data_to_send)
|
||||
self.initialized = True
|
||||
|
||||
async def send_headers(
|
||||
@ -78,7 +76,7 @@ class HTTP2Connection:
|
||||
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
|
||||
self.h2_state.send_headers(stream_id, headers)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
await self.stream.write(data_to_send, timeout)
|
||||
return stream_id
|
||||
|
||||
async def send_request_data(
|
||||
@ -104,12 +102,12 @@ class HTTP2Connection:
|
||||
chunk = data[idx : idx + chunk_size]
|
||||
self.h2_state.send_data(stream_id, chunk)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
await self.stream.write(data_to_send, timeout)
|
||||
|
||||
async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
|
||||
self.h2_state.end_stream(stream_id)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
await self.stream.write(data_to_send, timeout)
|
||||
|
||||
async def receive_response(
|
||||
self, stream_id: int, timeout: TimeoutConfig = None
|
||||
@ -150,14 +148,14 @@ class HTTP2Connection:
|
||||
) -> h2.events.Event:
|
||||
while not self.events[stream_id]:
|
||||
flag = self.timeout_flags[stream_id]
|
||||
data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
|
||||
data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
|
||||
events = self.h2_state.receive_data(data)
|
||||
for event in events:
|
||||
if getattr(event, "stream_id", 0):
|
||||
self.events[event.stream_id].append(event)
|
||||
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
await self.stream.write(data_to_send, timeout)
|
||||
|
||||
return self.events[stream_id].pop(0)
|
||||
|
||||
@ -173,4 +171,4 @@ class HTTP2Connection:
|
||||
return False
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.reader.is_connection_dropped()
|
||||
return self.stream.is_connection_dropped()
|
||||
|
||||
@ -6,7 +6,7 @@ import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig
|
||||
from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig
|
||||
|
||||
|
||||
class MockHTTP2Backend(AsyncioBackend):
|
||||
@ -20,16 +20,12 @@ class MockHTTP2Backend(AsyncioBackend):
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, str]:
|
||||
) -> BaseStream:
|
||||
self.server = MockHTTP2Server(self.app)
|
||||
return self.server, self.server, "HTTP/2"
|
||||
return self.server
|
||||
|
||||
|
||||
class MockHTTP2Server(BaseReader, BaseWriter):
|
||||
"""
|
||||
This class exposes Reader and Writer style interfaces.
|
||||
"""
|
||||
|
||||
class MockHTTP2Server(BaseStream):
|
||||
def __init__(self, app):
|
||||
config = h2.config.H2Configuration(client_side=False)
|
||||
self.conn = h2.connection.H2Connection(config=config)
|
||||
@ -38,15 +34,16 @@ class MockHTTP2Server(BaseReader, BaseWriter):
|
||||
self.requests = {}
|
||||
self.close_connection = False
|
||||
|
||||
# BaseReader interface
|
||||
# Stream interface
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
return "HTTP/2"
|
||||
|
||||
async def read(self, n, timeout, flag=None) -> bytes:
|
||||
await asyncio.sleep(0)
|
||||
send, self.buffer = self.buffer[:n], self.buffer[n:]
|
||||
return send
|
||||
|
||||
# BaseWriter interface
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
events = self.conn.receive_data(data)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user