Add ConcurrencyBackend.start_tls() (#263)
This commit is contained in:
parent
a4b93b91c0
commit
1872ae873b
@ -18,8 +18,8 @@ from ..config import PoolLimits, TimeoutConfig
|
||||
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
from .base import (
|
||||
BaseBackgroundManager,
|
||||
BasePoolSemaphore,
|
||||
BaseEvent,
|
||||
BasePoolSemaphore,
|
||||
BaseQueue,
|
||||
BaseStream,
|
||||
ConcurrencyBackend,
|
||||
@ -194,6 +194,44 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
|
||||
)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
stream: BaseStream,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseStream:
|
||||
|
||||
loop = self.loop
|
||||
if not hasattr(loop, "start_tls"): # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
|
||||
)
|
||||
|
||||
assert isinstance(stream, Stream)
|
||||
|
||||
stream_reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(stream_reader)
|
||||
transport = stream.stream_writer.transport
|
||||
|
||||
loop_start_tls = loop.start_tls # type: ignore
|
||||
transport = await asyncio.wait_for(
|
||||
loop_start_tls(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
sslcontext=ssl_context,
|
||||
server_hostname=hostname,
|
||||
),
|
||||
timeout=timeout.connect_timeout,
|
||||
)
|
||||
|
||||
stream_reader.set_transport(transport)
|
||||
stream.stream_reader = stream_reader
|
||||
stream.stream_writer = asyncio.StreamWriter(
|
||||
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
|
||||
)
|
||||
return stream
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
|
||||
@ -116,6 +116,15 @@ class ConcurrencyBackend:
|
||||
) -> BaseStream:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
stream: BaseStream,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseStream:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import typing
|
||||
|
||||
from .base import AsyncDispatcher
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..concurrency.asyncio import AsyncioBackend
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
from .base import AsyncDispatcher
|
||||
|
||||
|
||||
class ASGIDispatch(AsyncDispatcher):
|
||||
|
||||
@ -2,7 +2,6 @@ import functools
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from .base import AsyncDispatcher
|
||||
from ..concurrency.asyncio import AsyncioBackend
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..config import (
|
||||
@ -16,6 +15,7 @@ from ..config import (
|
||||
VerifyTypes,
|
||||
)
|
||||
from ..models import AsyncRequest, AsyncResponse, Origin
|
||||
from .base import AsyncDispatcher
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import typing
|
||||
|
||||
from .base import AsyncDispatcher
|
||||
from ..concurrency.asyncio import AsyncioBackend
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..config import (
|
||||
@ -13,6 +12,7 @@ from ..config import (
|
||||
VerifyTypes,
|
||||
)
|
||||
from ..models import AsyncRequest, AsyncResponse, Origin
|
||||
from .base import AsyncDispatcher
|
||||
from .connection import HTTPConnection
|
||||
|
||||
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from .base import AsyncDispatcher, Dispatcher
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from ..models import (
|
||||
@ -11,6 +10,7 @@ from ..models import (
|
||||
Response,
|
||||
ResponseContent,
|
||||
)
|
||||
from .base import AsyncDispatcher, Dispatcher
|
||||
|
||||
|
||||
class ThreadedDispatcher(AsyncDispatcher):
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import io
|
||||
import typing
|
||||
|
||||
from .base import Dispatcher
|
||||
from ..config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from ..models import Request, Response
|
||||
from .base import Dispatcher
|
||||
|
||||
|
||||
class WSGIDispatch(Dispatcher):
|
||||
|
||||
31
tests/test_concurrency.py
Normal file
31
tests/test_concurrency.py
Normal file
@ -0,0 +1,31 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
sys.version_info < (3, 7),
|
||||
reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_tls_on_socket_stream(https_server):
|
||||
"""
|
||||
See that the backend can make a connection without TLS then
|
||||
start TLS on an existing connection.
|
||||
"""
|
||||
backend = AsyncioBackend()
|
||||
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
|
||||
timeout = TimeoutConfig(5)
|
||||
|
||||
stream = await backend.connect("127.0.0.1", 8001, None, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert stream.stream_writer.get_extra_info("cipher", default=None) is None
|
||||
|
||||
stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
|
||||
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
|
||||
Loading…
Reference in New Issue
Block a user