HTTP/2 refactoring (#612)

* HTTP/2 refactoring

* Clean up flow control

* Remove extra blank line
This commit is contained in:
Tom Christie 2019-12-07 14:14:09 +00:00 committed by GitHub
parent bc54dd0399
commit f57bb2f142
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,3 @@
import functools
import typing
import h2.connection
@ -9,7 +8,6 @@ from ..concurrency.base import (
BaseEvent,
BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
lookup_backend,
)
from ..config import Timeout
@ -32,10 +30,10 @@ class HTTP2Connection:
self.socket = socket
self.backend = lookup_backend(backend)
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.state = h2.connection.H2Connection()
self.streams = {} # type: typing.Dict[int, HTTP2Stream]
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag]
self.window_update_received = {} # type: typing.Dict[int, BaseEvent]
self.init_started = False
@ -54,54 +52,28 @@ class HTTP2Connection:
# The very first stream is responsible for initiating the connection.
self.init_started = True
await self.send_connection_init(timeout)
stream_id = self.state.get_next_available_stream_id()
self.init_complete.set()
else:
# All other streams need to wait until the connection is established.
await self.init_complete.wait()
stream_id = self.state.get_next_available_stream_id()
stream_id = await self.send_headers(request, timeout)
stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state)
self.streams[stream_id] = stream
self.events[stream_id] = []
self.timeout_flags[stream_id] = TimeoutFlag()
self.window_update_received[stream_id] = self.backend.create_event()
status_code: typing.Optional[int] = None
headers: typing.Optional[list] = None
async def receive_response(stream_id: int, timeout: Timeout) -> None:
nonlocal status_code, headers
status_code, headers = await self.receive_response(stream_id, timeout)
await self.backend.fork(
self.send_request_data,
[stream_id, request.stream(), timeout],
receive_response,
[stream_id, timeout],
)
assert status_code is not None
assert headers is not None
content = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)
return Response(
status_code=status_code,
http_version="HTTP/2",
headers=headers,
content=content,
on_close=on_close,
request=request,
)
async def close(self) -> None:
await self.socket.close()
return await stream.send(request, timeout)
async def send_connection_init(self, timeout: Timeout) -> None:
"""
The HTTP/2 connection requires some initial setup before we can start
using individual request/response streams on it.
"""
# Need to set these manually here instead of manipulating via
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
# frames in addition to sending the undesired defaults.
self.h2_state.local_settings = Settings(
self.state.local_settings = Settings(
client=True,
initial_values={
# Disable PUSH_PROMISE frames from the server since we don't do anything
@ -116,16 +88,113 @@ class HTTP2Connection:
# Some websites (*cough* Yahoo *cough*) balk at this setting being
# present in the initial handshake since it's not defined in the original
# RFC despite the RFC mandating ignoring settings you don't know about.
del self.h2_state.local_settings[
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
]
del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL]
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.state.initiate_connection()
data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
async def send_headers(self, request: Request, timeout: Timeout) -> int:
stream_id = self.h2_state.get_next_available_stream_id()
@property
def is_closed(self) -> bool:
    """
    Report whether this connection is closed.

    This implementation unconditionally answers `False`.
    """
    return False
def is_connection_dropped(self) -> bool:
    """
    Delegate to the socket stream's own `is_connection_dropped()` check
    and hand back whatever it reports.
    """
    sock = self.socket
    return sock.is_connection_dropped()
async def close(self) -> None:
    """
    Close the underlying socket stream.
    """
    sock = self.socket
    await sock.close()
async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int:
    """
    Returns the maximum allowable outgoing flow for a given stream.

    The result is the smaller of the flow control window that the H2 state
    reports for `stream_id` and the maximum outbound frame size, and is
    always strictly positive.

    If no flow is currently available, then waits on the network until
    incoming frames (such as WindowUpdated) have increased the flow rate.
    A window can be driven *negative* when the peer lowers
    SETTINGS_INITIAL_WINDOW_SIZE, so any non-positive value is treated as
    "nothing may be sent yet" rather than being handed back to the caller.

    https://tools.ietf.org/html/rfc7540#section-6.9
    """
    while True:
        stream_window = self.state.local_flow_control_window(stream_id)
        max_frame_size = self.state.max_outbound_frame_size
        flow = min(stream_window, max_frame_size)
        if flow > 0:
            return flow
        # No capacity: pull more frames off the network and re-evaluate.
        await self.receive_events(timeout)
async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
    """
    Pop and return the oldest queued event for `stream_id`.

    While that stream's queue is empty, keeps reading from the network
    until something has been queued for it.
    """
    while True:
        queued = self.events[stream_id]
        if queued:
            return queued.pop(0)
        await self.receive_events(timeout)
async def receive_events(self, timeout: Timeout) -> None:
    """
    Read some data from the network, and update the H2 state.

    Each resulting event is appended to the queue of the stream it belongs
    to, provided that stream is registered in `self.events`; events for any
    other stream id are not queued. Any event that carries an `error_code`
    attribute is raised as a `ProtocolError`.

    Afterwards, any bytes the H2 state has queued in response are written
    to the socket. The write is skipped when nothing is pending.
    """
    data = await self.socket.read(self.READ_NUM_BYTES, timeout)
    events = self.state.receive_data(data)
    for event in events:
        # Events without a `stream_id` attribute are filed under id 0.
        event_stream_id = getattr(event, "stream_id", 0)
        logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}")

        if hasattr(event, "error_code"):
            raise ProtocolError(event)

        if event_stream_id in self.events:
            self.events[event_stream_id].append(event)

    # Only touch the socket when there is actually something to send,
    # matching `send_outgoing_data`.
    data_to_send = self.state.data_to_send()
    if data_to_send:
        await self.socket.write(data_to_send, timeout)
async def send_outgoing_data(self, timeout: Timeout) -> None:
    """
    Write whatever bytes the H2 state has queued for sending to the socket.

    Does nothing when no bytes are pending.
    """
    pending = self.state.data_to_send()
    if not pending:
        return
    await self.socket.write(pending, timeout)
async def close_stream(self, stream_id: int) -> None:
    """
    Forget a finished stream and its queued events.

    Once no streams remain, the `on_release` callback is awaited, if one
    was supplied.
    """
    del self.streams[stream_id]
    del self.events[stream_id]

    if self.streams or self.on_release is None:
        return
    await self.on_release()
class HTTP2Stream:
def __init__(
    self,
    stream_id: int,
    connection: HTTP2Connection,
    state: h2.connection.H2Connection,
) -> None:
    """
    Bind this stream to its id, the connection that owns it, and the
    H2 state object it operates on.
    """
    # Plain attribute storage; nothing is sent or read at construction time.
    self.state = state
    self.connection = connection
    self.stream_id = stream_id
async def send(self, request: Request, timeout: Timeout) -> Response:
    """
    Perform one request/response exchange on this stream.

    The request headers and body are sent first, then the response status
    and headers are read. The returned `Response` carries a body iterator
    for the remaining data and uses this stream's `close` as its
    `on_close` callback.
    """
    # Outgoing half: headers, then body.
    await self.send_headers(request, timeout)
    await self.send_body(request, timeout)

    # Incoming half: status line and headers up front, body on demand.
    status_code, headers = await self.receive_response(timeout)
    response = Response(
        status_code=status_code,
        http_version="HTTP/2",
        headers=headers,
        content=self.body_iter(timeout),
        on_close=self.close,
        request=request,
    )
    return response
async def send_headers(self, request: Request, timeout: Timeout) -> None:
headers = [
(b":method", request.method.encode("ascii")),
(b":authority", request.url.authority.encode("ascii")),
@ -135,67 +204,37 @@ class HTTP2Connection:
logger.trace(
f"send_headers "
f"stream_id={stream_id} "
f"stream_id={self.stream_id} "
f"method={request.method!r} "
f"target={request.url.full_path!r} "
f"headers={headers!r}"
)
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.socket.write(data_to_send, timeout)
return stream_id
self.state.send_headers(self.stream_id, headers)
await self.connection.send_outgoing_data(timeout)
async def send_request_data(
self, stream_id: int, stream: typing.AsyncIterator[bytes], timeout: Timeout,
) -> None:
try:
async for data in stream:
await self.send_data(stream_id, data, timeout)
await self.end_stream(stream_id, timeout)
finally:
# Once we've sent the request we should enable read timeouts.
self.timeout_flags[stream_id].set_read_timeouts()
async def send_data(self, stream_id: int, data: bytes, timeout: Timeout) -> None:
while data:
# The data will be divided into frames to send based on the flow control
# window and the maximum frame size. Because the flow control window
# can decrease in size, even possibly to zero, this will loop until all the
# data is sent. In http2 specification:
# https://tools.ietf.org/html/rfc7540#section-6.9
flow_control = self.h2_state.local_flow_control_window(stream_id)
chunk_size = min(
len(data), flow_control, self.h2_state.max_outbound_frame_size
)
if chunk_size == 0:
# this means that the flow control window is 0 (either for the stream
# or the connection one), and no data can be sent until the flow control
# window is updated.
await self.window_update_received[stream_id].wait()
self.window_update_received[stream_id].clear()
else:
async def send_body(self, request: Request, timeout: Timeout) -> None:
logger.trace(f"send_body stream_id={self.stream_id}")
async for data in request.stream():
while data:
max_flow = await self.connection.wait_for_outgoing_flow(
self.stream_id, timeout
)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self.h2_state.send_data(stream_id, chunk)
data_to_send = self.h2_state.data_to_send()
await self.socket.write(data_to_send, timeout)
self.state.send_data(self.stream_id, chunk)
await self.connection.send_outgoing_data(timeout)
async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
logger.trace(f"end_stream stream_id={stream_id}")
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
await self.socket.write(data_to_send, timeout)
self.state.end_stream(self.stream_id)
await self.connection.send_outgoing_data(timeout)
async def receive_response(
self, stream_id: int, timeout: Timeout
self, timeout: Timeout
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Read the response status and headers from the network.
"""
while True:
event = await self.receive_event(stream_id, timeout)
# As soon as we start seeing response events, we should enable
# read timeouts, if we haven't already.
self.timeout_flags[stream_id].set_read_timeouts()
event = await self.connection.wait_for_event(self.stream_id, timeout)
if isinstance(event, h2.events.ResponseReceived):
break
@ -209,65 +248,17 @@ class HTTP2Connection:
return (status_code, headers)
async def body_iter(
self, stream_id: int, timeout: Timeout
) -> typing.AsyncIterator[bytes]:
async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]:
while True:
event = await self.receive_event(stream_id, timeout)
event = await self.connection.wait_for_event(self.stream_id, timeout)
if isinstance(event, h2.events.DataReceived):
self.h2_state.acknowledge_received_data(
event.flow_controlled_length, stream_id
self.state.acknowledge_received_data(
event.flow_controlled_length, self.stream_id
)
await self.connection.send_outgoing_data(timeout)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
break
# Pre-refactor implementation: returns the next queued event for `stream_id`,
# reading from the socket for as long as that stream's queue is empty.
# NOTE(review): indentation was lost in this view, so the exact nesting of the
# trailing flush relative to the `while` loop should be confirmed.
async def receive_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
while not self.events[stream_id]:
# Reads are performed with this stream's own timeout flag.
flag = self.timeout_flags[stream_id]
data = await self.socket.read(self.READ_NUM_BYTES, timeout, flag=flag)
events = self.h2_state.receive_data(data)
for event in events:
# Events without a `stream_id` attribute are treated as id 0.
event_stream_id = getattr(event, "stream_id", 0)
logger.trace(
f"receive_event stream_id={event_stream_id} event={event!r}"
)
# Any event carrying an `error_code` attribute aborts with ProtocolError.
if hasattr(event, "error_code"):
raise ProtocolError(event)
# A window update on id 0 wakes every stream's waiter; otherwise only
# the waiter registered for the named stream is woken.
if isinstance(event, h2.events.WindowUpdated):
if event_stream_id == 0:
for window_update_event in self.window_update_received.values():
window_update_event.set()
else:
try:
self.window_update_received[event_stream_id].set()
except KeyError: # pragma: no cover
# the window_update_received dictionary is only relevant
# when sending data, which should never raise a KeyError
# here.
pass
# Only events with a non-zero stream id are queued.
# NOTE(review): no membership check here, so an id missing from
# `self.events` would raise KeyError -- confirm that cannot happen.
if event_stream_id:
self.events[event.stream_id].append(event)
# Write out whatever the H2 state queued while processing the data.
data_to_send = self.h2_state.data_to_send()
await self.socket.write(data_to_send, timeout)
# Hand back the oldest queued event for the requested stream.
return self.events[stream_id].pop(0)
async def response_closed(self, stream_id: int) -> None:
    """
    Drop the per-stream bookkeeping for `stream_id` once its response has
    been closed.

    When no stream has an event queue left, the `on_release` callback is
    awaited, if one was supplied.
    """
    del self.events[stream_id]
    del self.timeout_flags[stream_id]
    del self.window_update_received[stream_id]

    if self.events or self.on_release is None:
        return
    await self.on_release()
@property
def is_closed(self) -> bool:
    """
    Report whether this connection is closed.

    This implementation unconditionally answers `False`.
    """
    return False
def is_connection_dropped(self) -> bool:
    """
    Delegate to the socket stream's own `is_connection_dropped()` check
    and hand back whatever it reports.
    """
    sock = self.socket
    return sock.is_connection_dropped()
async def close(self) -> None:
    """
    Ask the owning connection to close this stream, identified by its
    stream id.
    """
    owner = self.connection
    await owner.close_stream(self.stream_id)