Make start_tls a method on streams & return a new stream (#484)
* Move start_tls to stream & return a new stream * asyncio: Keep a reference to the inner stream when upgrading to TLS
This commit is contained in:
parent
ad38db82f9
commit
644e8fc5b6
@ -51,6 +51,44 @@ class TCPStream(BaseTCPStream):
|
||||
self.stream_writer = stream_writer
|
||||
self.timeout = timeout
|
||||
|
||||
self._inner: typing.Optional[TCPStream] = None
|
||||
|
||||
async def start_tls(
|
||||
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
|
||||
) -> BaseTCPStream:
|
||||
loop = asyncio.get_event_loop()
|
||||
if not hasattr(loop, "start_tls"): # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
|
||||
)
|
||||
|
||||
stream_reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(stream_reader)
|
||||
transport = self.stream_writer.transport
|
||||
|
||||
loop_start_tls = loop.start_tls # type: ignore
|
||||
transport = await asyncio.wait_for(
|
||||
loop_start_tls(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
sslcontext=ssl_context,
|
||||
server_hostname=hostname,
|
||||
),
|
||||
timeout=timeout.connect_timeout,
|
||||
)
|
||||
|
||||
stream_reader.set_transport(transport)
|
||||
stream_writer = asyncio.StreamWriter(
|
||||
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
|
||||
)
|
||||
|
||||
ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
|
||||
# When we return a new TCPStream with new StreamReader/StreamWriter instances,
|
||||
# we need to keep references to the old StreamReader/StreamWriter so that they
|
||||
# are not garbage collected and closed while we're still using them.
|
||||
ssl_stream._inner = self
|
||||
return ssl_stream
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
ssl_object = self.stream_writer.get_extra_info("ssl_object")
|
||||
|
||||
@ -201,44 +239,6 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
|
||||
)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
stream: BaseTCPStream,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseTCPStream:
|
||||
|
||||
loop = self.loop
|
||||
if not hasattr(loop, "start_tls"): # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
|
||||
)
|
||||
|
||||
assert isinstance(stream, TCPStream)
|
||||
|
||||
stream_reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(stream_reader)
|
||||
transport = stream.stream_writer.transport
|
||||
|
||||
loop_start_tls = loop.start_tls # type: ignore
|
||||
transport = await asyncio.wait_for(
|
||||
loop_start_tls(
|
||||
transport=transport,
|
||||
protocol=protocol,
|
||||
sslcontext=ssl_context,
|
||||
server_hostname=hostname,
|
||||
),
|
||||
timeout=timeout.connect_timeout,
|
||||
)
|
||||
|
||||
stream_reader.set_transport(transport)
|
||||
stream.stream_reader = stream_reader
|
||||
stream.stream_writer = asyncio.StreamWriter(
|
||||
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
|
||||
)
|
||||
return stream
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
|
||||
@ -47,6 +47,11 @@ class BaseTCPStream:
|
||||
def get_http_version(self) -> str:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def start_tls(
|
||||
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
|
||||
) -> "BaseTCPStream":
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def read(
|
||||
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
|
||||
) -> bytes:
|
||||
@ -119,15 +124,6 @@ class ConcurrencyBackend:
|
||||
) -> BaseTCPStream:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
stream: BaseTCPStream,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseTCPStream:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
@ -34,6 +34,26 @@ class TCPStream(BaseTCPStream):
|
||||
self.write_buffer = b""
|
||||
self.write_lock = trio.Lock()
|
||||
|
||||
async def start_tls(
|
||||
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
|
||||
) -> BaseTCPStream:
|
||||
# Check that the write buffer is empty. We should never start a TLS stream
|
||||
# while there is still pending data to write.
|
||||
assert self.write_buffer == b""
|
||||
|
||||
connect_timeout = _or_inf(timeout.connect_timeout)
|
||||
ssl_stream = trio.SSLStream(
|
||||
self.stream, ssl_context=ssl_context, server_hostname=hostname
|
||||
)
|
||||
|
||||
with trio.move_on_after(connect_timeout) as cancel_scope:
|
||||
await ssl_stream.do_handshake()
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise ConnectTimeout()
|
||||
|
||||
return TCPStream(ssl_stream, self.timeout)
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
if not isinstance(self.stream, trio.SSLStream):
|
||||
return "HTTP/1.1"
|
||||
@ -171,30 +191,6 @@ class TrioBackend(ConcurrencyBackend):
|
||||
|
||||
return TCPStream(stream=stream, timeout=timeout)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
stream: BaseTCPStream,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseTCPStream:
|
||||
assert isinstance(stream, TCPStream)
|
||||
|
||||
connect_timeout = _or_inf(timeout.connect_timeout)
|
||||
ssl_stream = trio.SSLStream(
|
||||
stream.stream, ssl_context=ssl_context, server_hostname=hostname
|
||||
)
|
||||
|
||||
with trio.move_on_after(connect_timeout) as cancel_scope:
|
||||
await ssl_stream.do_handshake()
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise ConnectTimeout()
|
||||
|
||||
stream.stream = ssl_stream
|
||||
|
||||
return stream
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
|
||||
@ -192,11 +192,8 @@ class HTTPProxy(ConnectionPool):
|
||||
f"proxy_url={self.proxy_url!r} "
|
||||
f"origin={origin!r}"
|
||||
)
|
||||
stream = await self.backend.start_tls(
|
||||
stream=stream,
|
||||
hostname=origin.host,
|
||||
ssl_context=ssl_context,
|
||||
timeout=timeout,
|
||||
stream = await stream.start_tls(
|
||||
hostname=origin.host, ssl_context=ssl_context, timeout=timeout
|
||||
)
|
||||
http_version = stream.get_http_version()
|
||||
logger.debug(
|
||||
|
||||
@ -184,16 +184,6 @@ class MockRawSocketBackend:
|
||||
)
|
||||
return self.stream
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
stream: BaseTCPStream,
|
||||
hostname: str,
|
||||
ssl_context: ssl.SSLContext,
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseTCPStream:
|
||||
self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
|
||||
return self.stream
|
||||
|
||||
# Defer all other attributes and methods to the underlying backend.
|
||||
def __getattr__(self, name: str) -> typing.Any:
|
||||
return getattr(self.backend, name)
|
||||
@ -203,6 +193,12 @@ class MockRawSocketStream(BaseTCPStream):
|
||||
def __init__(self, backend: MockRawSocketBackend):
|
||||
self.backend = backend
|
||||
|
||||
async def start_tls(
|
||||
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
|
||||
) -> BaseTCPStream:
|
||||
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
|
||||
return MockRawSocketStream(self.backend)
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
return "HTTP/1.1"
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is None
|
||||
|
||||
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
|
||||
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is not None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user