Make start_tls a method on streams & return a new stream (#484)

* Move start_tls to stream & return a new stream

* asyncio: Keep a reference to the inner stream when upgrading to TLS
This commit is contained in:
Jamie Hewland 2019-10-20 12:59:16 +02:00 committed by Florimond Manca
parent ad38db82f9
commit 644e8fc5b6
6 changed files with 72 additions and 87 deletions

View File

@ -51,6 +51,44 @@ class TCPStream(BaseTCPStream):
self.stream_writer = stream_writer
self.timeout = timeout
self._inner: typing.Optional[TCPStream] = None
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
loop = asyncio.get_event_loop()
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)
stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = self.stream_writer.transport
loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)
stream_reader.set_transport(transport)
stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
# When we return a new TCPStream with new StreamReader/StreamWriter instances,
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
ssl_stream._inner = self
return ssl_stream
def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")
@ -201,44 +239,6 @@ class AsyncioBackend(ConcurrencyBackend):
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)
assert isinstance(stream, TCPStream)
stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = stream.stream_writer.transport
loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)
stream_reader.set_transport(transport)
stream.stream_reader = stream_reader
stream.stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
return stream
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:

View File

@ -47,6 +47,11 @@ class BaseTCPStream:
def get_http_version(self) -> str:
raise NotImplementedError() # pragma: no cover
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> "BaseTCPStream":
raise NotImplementedError() # pragma: no cover
async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
@ -119,15 +124,6 @@ class ConcurrencyBackend:
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover
async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover

View File

@ -34,6 +34,26 @@ class TCPStream(BaseTCPStream):
self.write_buffer = b""
self.write_lock = trio.Lock()
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
# Check that the write buffer is empty. We should never start a TLS stream
# while there is still pending data to write.
assert self.write_buffer == b""
connect_timeout = _or_inf(timeout.connect_timeout)
ssl_stream = trio.SSLStream(
self.stream, ssl_context=ssl_context, server_hostname=hostname
)
with trio.move_on_after(connect_timeout) as cancel_scope:
await ssl_stream.do_handshake()
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
return TCPStream(ssl_stream, self.timeout)
def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
return "HTTP/1.1"
@ -171,30 +191,6 @@ class TrioBackend(ConcurrencyBackend):
return TCPStream(stream=stream, timeout=timeout)
async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
assert isinstance(stream, TCPStream)
connect_timeout = _or_inf(timeout.connect_timeout)
ssl_stream = trio.SSLStream(
stream.stream, ssl_context=ssl_context, server_hostname=hostname
)
with trio.move_on_after(connect_timeout) as cancel_scope:
await ssl_stream.do_handshake()
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
stream.stream = ssl_stream
return stream
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:

View File

@ -192,11 +192,8 @@ class HTTPProxy(ConnectionPool):
f"proxy_url={self.proxy_url!r} "
f"origin={origin!r}"
)
stream = await self.backend.start_tls(
stream=stream,
hostname=origin.host,
ssl_context=ssl_context,
timeout=timeout,
stream = await stream.start_tls(
hostname=origin.host, ssl_context=ssl_context, timeout=timeout
)
http_version = stream.get_http_version()
logger.debug(

View File

@ -184,16 +184,6 @@ class MockRawSocketBackend:
)
return self.stream
async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return self.stream
# Defer all other attributes and methods to the underlying backend.
def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)
@ -203,6 +193,12 @@ class MockRawSocketStream(BaseTCPStream):
def __init__(self, backend: MockRawSocketBackend):
self.backend = backend
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return MockRawSocketStream(self.backend)
def get_http_version(self) -> str:
return "HTTP/1.1"

View File

@ -45,7 +45,7 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is None
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is not None