Add ConcurrencyBackend.start_tls() (#263)

This commit is contained in:
Seth Michael Larson 2019-08-24 10:04:14 -05:00 committed by GitHub
parent a4b93b91c0
commit 1872ae873b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 85 additions and 7 deletions

View File

@ -18,8 +18,8 @@ from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseStream,
ConcurrencyBackend,
@ -194,6 +194,44 @@ class AsyncioBackend(ConcurrencyBackend):
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
async def start_tls(
self,
stream: BaseStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseStream:
loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)
assert isinstance(stream, Stream)
stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = stream.stream_writer.transport
loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)
stream_reader.set_transport(transport)
stream.stream_reader = stream_reader
stream.stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
return stream
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:

View File

@ -116,6 +116,15 @@ class ConcurrencyBackend:
) -> BaseStream:
raise NotImplementedError() # pragma: no cover
async def start_tls(
self,
stream: BaseStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseStream:
raise NotImplementedError() # pragma: no cover
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover

View File

@ -1,10 +1,10 @@
import typing
from .base import AsyncDispatcher
from ..concurrency.base import ConcurrencyBackend
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import AsyncRequest, AsyncResponse
from .base import AsyncDispatcher
class ASGIDispatch(AsyncDispatcher):

View File

@ -2,7 +2,6 @@ import functools
import ssl
import typing
from .base import AsyncDispatcher
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import (
@ -16,6 +15,7 @@ from ..config import (
VerifyTypes,
)
from ..models import AsyncRequest, AsyncResponse, Origin
from .base import AsyncDispatcher
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection

View File

@ -1,6 +1,5 @@
import typing
from .base import AsyncDispatcher
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import (
@ -13,6 +12,7 @@ from ..config import (
VerifyTypes,
)
from ..models import AsyncRequest, AsyncResponse, Origin
from .base import AsyncDispatcher
from .connection import HTTPConnection
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]

View File

@ -1,4 +1,3 @@
from .base import AsyncDispatcher, Dispatcher
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import (
@ -11,6 +10,7 @@ from ..models import (
Response,
ResponseContent,
)
from .base import AsyncDispatcher, Dispatcher
class ThreadedDispatcher(AsyncDispatcher):

View File

@ -1,9 +1,9 @@
import io
import typing
from .base import Dispatcher
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import Request, Response
from .base import Dispatcher
class WSGIDispatch(Dispatcher):

31
tests/test_concurrency.py Normal file
View File

@ -0,0 +1,31 @@
import sys
import pytest
from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
@pytest.mark.xfail(
sys.version_info < (3, 7),
reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
)
@pytest.mark.asyncio
async def test_start_tls_on_socket_stream(https_server):
"""
See that the backend can make a connection without TLS then
start TLS on an existing connection.
"""
backend = AsyncioBackend()
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)
stream = await backend.connect("127.0.0.1", 8001, None, timeout)
assert stream.is_connection_dropped() is False
assert stream.stream_writer.get_extra_info("cipher", default=None) is None
stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
assert stream.is_connection_dropped() is False
assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")