* Refactoring h11 implementation * More h11 refactoring * Support early connection closes on H11 connections * Tweak comment * Refactor concurrent read/writes * Drop WriteTimeout masking * Linting * Use concurrent read/writes for HTTP2 * Push background sending into ConcurrencyBackend
224 lines
6.8 KiB
Python
224 lines
6.8 KiB
Python
"""
|
|
The `Reader` and `Writer` classes here provide a lightweight layer over
|
|
`asyncio.StreamReader` and `asyncio.StreamWriter`.
|
|
|
|
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
|
|
|
|
These classes help encapsulate the timeout logic, make it easier to unit-test
|
|
protocols, and help keep the rest of the package more `async`/`await`
|
|
based, and less strictly `asyncio`-specific.
|
|
"""
|
|
import asyncio
|
|
import functools
|
|
import ssl
|
|
import typing
|
|
from types import TracebackType
|
|
|
|
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
|
|
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
|
from .interfaces import (
|
|
BaseBackgroundManager,
|
|
BasePoolSemaphore,
|
|
BaseReader,
|
|
BaseWriter,
|
|
ConcurrencyBackend,
|
|
Protocol,
|
|
)
|
|
|
|
SSL_MONKEY_PATCH_APPLIED = False
|
|
|
|
|
|
def ssl_monkey_patch() -> None:
|
|
"""
|
|
Monky-patch for https://bugs.python.org/issue36709
|
|
|
|
This prevents console errors when outstanding HTTPS connections
|
|
still exist at the point of exiting.
|
|
|
|
Clients which have been opened using a `with` block, or which have
|
|
had `close()` closed, will not exhibit this issue in the first place.
|
|
"""
|
|
MonkeyPatch = asyncio.selector_events._SelectorSocketTransport # type: ignore
|
|
|
|
_write = MonkeyPatch.write
|
|
|
|
def _fixed_write(self, data: bytes) -> None: # type: ignore
|
|
if self._loop and not self._loop.is_closed():
|
|
_write(self, data)
|
|
|
|
MonkeyPatch.write = _fixed_write
|
|
|
|
|
|
class Reader(BaseReader):
|
|
def __init__(
|
|
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
|
|
) -> None:
|
|
self.stream_reader = stream_reader
|
|
self.timeout = timeout
|
|
|
|
async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
|
|
if timeout is None:
|
|
timeout = self.timeout
|
|
|
|
try:
|
|
data = await asyncio.wait_for(
|
|
self.stream_reader.read(n), timeout.read_timeout
|
|
)
|
|
except asyncio.TimeoutError:
|
|
raise ReadTimeout()
|
|
|
|
return data
|
|
|
|
|
|
class Writer(BaseWriter):
|
|
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
|
|
self.stream_writer = stream_writer
|
|
self.timeout = timeout
|
|
|
|
def write_no_block(self, data: bytes) -> None:
|
|
self.stream_writer.write(data) # pragma: nocover
|
|
|
|
async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
|
|
if not data:
|
|
return
|
|
|
|
if timeout is None:
|
|
timeout = self.timeout
|
|
|
|
self.stream_writer.write(data)
|
|
try:
|
|
await asyncio.wait_for( # type: ignore
|
|
self.stream_writer.drain(), timeout.write_timeout
|
|
)
|
|
except asyncio.TimeoutError:
|
|
raise WriteTimeout()
|
|
|
|
async def close(self) -> None:
|
|
self.stream_writer.close()
|
|
|
|
|
|
class PoolSemaphore(BasePoolSemaphore):
|
|
def __init__(self, pool_limits: PoolLimits):
|
|
self.pool_limits = pool_limits
|
|
|
|
@property
|
|
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
|
if not hasattr(self, "_semaphore"):
|
|
max_connections = self.pool_limits.hard_limit
|
|
if max_connections is None:
|
|
self._semaphore = None
|
|
else:
|
|
self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
|
|
return self._semaphore
|
|
|
|
async def acquire(self) -> None:
|
|
if self.semaphore is None:
|
|
return
|
|
|
|
timeout = self.pool_limits.pool_timeout
|
|
try:
|
|
await asyncio.wait_for(self.semaphore.acquire(), timeout)
|
|
except asyncio.TimeoutError:
|
|
raise PoolTimeout()
|
|
|
|
def release(self) -> None:
|
|
if self.semaphore is None:
|
|
return
|
|
|
|
self.semaphore.release()
|
|
|
|
|
|
class AsyncioBackend(ConcurrencyBackend):
|
|
def __init__(self) -> None:
|
|
global SSL_MONKEY_PATCH_APPLIED
|
|
|
|
if not SSL_MONKEY_PATCH_APPLIED:
|
|
ssl_monkey_patch()
|
|
SSL_MONKEY_PATCH_APPLIED = True
|
|
|
|
@property
|
|
def loop(self) -> asyncio.AbstractEventLoop:
|
|
if not hasattr(self, "_loop"):
|
|
try:
|
|
self._loop = asyncio.get_event_loop()
|
|
except RuntimeError:
|
|
self._loop = asyncio.new_event_loop()
|
|
return self._loop
|
|
|
|
async def connect(
|
|
self,
|
|
hostname: str,
|
|
port: int,
|
|
ssl_context: typing.Optional[ssl.SSLContext],
|
|
timeout: TimeoutConfig,
|
|
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
|
|
try:
|
|
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
|
|
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
|
timeout.connect_timeout,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
raise ConnectTimeout()
|
|
|
|
ssl_object = stream_writer.get_extra_info("ssl_object")
|
|
if ssl_object is None:
|
|
ident = "http/1.1"
|
|
else:
|
|
ident = ssl_object.selected_alpn_protocol()
|
|
if ident is None:
|
|
ident = ssl_object.selected_npn_protocol()
|
|
|
|
reader = Reader(stream_reader=stream_reader, timeout=timeout)
|
|
writer = Writer(stream_writer=stream_writer, timeout=timeout)
|
|
protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
|
|
|
|
return (reader, writer, protocol)
|
|
|
|
async def run_in_threadpool(
|
|
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
|
) -> typing.Any:
|
|
if kwargs:
|
|
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
|
|
func = functools.partial(func, **kwargs)
|
|
return await self.loop.run_in_executor(None, func, *args)
|
|
|
|
def run(
|
|
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
|
) -> typing.Any:
|
|
loop = self.loop
|
|
if loop.is_running():
|
|
self._loop = asyncio.new_event_loop()
|
|
try:
|
|
return self.loop.run_until_complete(coroutine(*args, **kwargs))
|
|
finally:
|
|
self._loop = loop
|
|
|
|
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
|
return PoolSemaphore(limits)
|
|
|
|
def background_manager(
|
|
self, coroutine: typing.Callable, args: typing.Any
|
|
) -> "BackgroundManager":
|
|
return BackgroundManager(coroutine, args)
|
|
|
|
|
|
class BackgroundManager(BaseBackgroundManager):
|
|
def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
|
|
self.coroutine = coroutine
|
|
self.args = args
|
|
|
|
async def __aenter__(self) -> "BackgroundManager":
|
|
loop = asyncio.get_event_loop()
|
|
self.task = loop.create_task(self.coroutine(*self.args))
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: typing.Type[BaseException] = None,
|
|
exc_value: BaseException = None,
|
|
traceback: TracebackType = None,
|
|
) -> None:
|
|
await self.task
|
|
if exc_type is None:
|
|
self.task.result()
|