HTTP/2 refactoring (#612)
* HTTP/2 refactoring * Clean up flow control * Remove extra blank line
This commit is contained in:
parent
bc54dd0399
commit
f57bb2f142
@ -1,4 +1,3 @@
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
@ -9,7 +8,6 @@ from ..concurrency.base import (
|
||||
BaseEvent,
|
||||
BaseSocketStream,
|
||||
ConcurrencyBackend,
|
||||
TimeoutFlag,
|
||||
lookup_backend,
|
||||
)
|
||||
from ..config import Timeout
|
||||
@ -32,10 +30,10 @@ class HTTP2Connection:
|
||||
self.socket = socket
|
||||
self.backend = lookup_backend(backend)
|
||||
self.on_release = on_release
|
||||
self.h2_state = h2.connection.H2Connection()
|
||||
self.state = h2.connection.H2Connection()
|
||||
|
||||
self.streams = {} # type: typing.Dict[int, HTTP2Stream]
|
||||
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
|
||||
self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag]
|
||||
self.window_update_received = {} # type: typing.Dict[int, BaseEvent]
|
||||
|
||||
self.init_started = False
|
||||
|
||||
@ -54,54 +52,28 @@ class HTTP2Connection:
|
||||
# The very first stream is responsible for initiating the connection.
|
||||
self.init_started = True
|
||||
await self.send_connection_init(timeout)
|
||||
stream_id = self.state.get_next_available_stream_id()
|
||||
self.init_complete.set()
|
||||
else:
|
||||
# All other streams need to wait until the connection is established.
|
||||
await self.init_complete.wait()
|
||||
stream_id = self.state.get_next_available_stream_id()
|
||||
|
||||
stream_id = await self.send_headers(request, timeout)
|
||||
|
||||
stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state)
|
||||
self.streams[stream_id] = stream
|
||||
self.events[stream_id] = []
|
||||
self.timeout_flags[stream_id] = TimeoutFlag()
|
||||
self.window_update_received[stream_id] = self.backend.create_event()
|
||||
|
||||
status_code: typing.Optional[int] = None
|
||||
headers: typing.Optional[list] = None
|
||||
|
||||
async def receive_response(stream_id: int, timeout: Timeout) -> None:
|
||||
nonlocal status_code, headers
|
||||
status_code, headers = await self.receive_response(stream_id, timeout)
|
||||
|
||||
await self.backend.fork(
|
||||
self.send_request_data,
|
||||
[stream_id, request.stream(), timeout],
|
||||
receive_response,
|
||||
[stream_id, timeout],
|
||||
)
|
||||
|
||||
assert status_code is not None
|
||||
assert headers is not None
|
||||
|
||||
content = self.body_iter(stream_id, timeout)
|
||||
on_close = functools.partial(self.response_closed, stream_id=stream_id)
|
||||
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
http_version="HTTP/2",
|
||||
headers=headers,
|
||||
content=content,
|
||||
on_close=on_close,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.socket.close()
|
||||
return await stream.send(request, timeout)
|
||||
|
||||
async def send_connection_init(self, timeout: Timeout) -> None:
|
||||
"""
|
||||
The HTTP/2 connection requires some initial setup before we can start
|
||||
using individual request/response streams on it.
|
||||
"""
|
||||
|
||||
# Need to set these manually here instead of manipulating via
|
||||
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
|
||||
# frames in addition to sending the undesired defaults.
|
||||
self.h2_state.local_settings = Settings(
|
||||
self.state.local_settings = Settings(
|
||||
client=True,
|
||||
initial_values={
|
||||
# Disable PUSH_PROMISE frames from the server since we don't do anything
|
||||
@ -116,16 +88,113 @@ class HTTP2Connection:
|
||||
# Some websites (*cough* Yahoo *cough*) balk at this setting being
|
||||
# present in the initial handshake since it's not defined in the original
|
||||
# RFC despite the RFC mandating ignoring settings you don't know about.
|
||||
del self.h2_state.local_settings[
|
||||
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
|
||||
]
|
||||
del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL]
|
||||
|
||||
self.h2_state.initiate_connection()
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.state.initiate_connection()
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def send_headers(self, request: Request, timeout: Timeout) -> int:
|
||||
stream_id = self.h2_state.get_next_available_stream_id()
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.socket.is_connection_dropped()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.socket.close()
|
||||
|
||||
async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int:
|
||||
"""
|
||||
Returns the maximum allowable outgoing flow for a given stream.
|
||||
|
||||
If the allowable flow is zero, then waits on the network until
|
||||
WindowUpdated frames have increased the flow rate.
|
||||
|
||||
https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
"""
|
||||
local_flow = self.state.local_flow_control_window(stream_id)
|
||||
connection_flow = self.state.max_outbound_frame_size
|
||||
flow = min(local_flow, connection_flow)
|
||||
while flow == 0:
|
||||
await self.receive_events(timeout)
|
||||
local_flow = self.state.local_flow_control_window(stream_id)
|
||||
connection_flow = self.state.max_outbound_frame_size
|
||||
flow = min(local_flow, connection_flow)
|
||||
return flow
|
||||
|
||||
async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
|
||||
"""
|
||||
Returns the next event for a given stream.
|
||||
|
||||
If no events are available yet, then waits on the network until
|
||||
an event is available.
|
||||
"""
|
||||
while not self.events[stream_id]:
|
||||
await self.receive_events(timeout)
|
||||
return self.events[stream_id].pop(0)
|
||||
|
||||
async def receive_events(self, timeout: Timeout) -> None:
|
||||
"""
|
||||
Read some data from the network, and update the H2 state.
|
||||
"""
|
||||
data = await self.socket.read(self.READ_NUM_BYTES, timeout)
|
||||
events = self.state.receive_data(data)
|
||||
for event in events:
|
||||
event_stream_id = getattr(event, "stream_id", 0)
|
||||
logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}")
|
||||
|
||||
if hasattr(event, "error_code"):
|
||||
raise ProtocolError(event)
|
||||
|
||||
if event_stream_id in self.events:
|
||||
self.events[event_stream_id].append(event)
|
||||
|
||||
data_to_send = self.state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def send_outgoing_data(self, timeout: Timeout) -> None:
|
||||
data_to_send = self.state.data_to_send()
|
||||
if data_to_send:
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
async def close_stream(self, stream_id: int) -> None:
|
||||
del self.streams[stream_id]
|
||||
del self.events[stream_id]
|
||||
|
||||
if not self.streams and self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
|
||||
class HTTP2Stream:
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: int,
|
||||
connection: HTTP2Connection,
|
||||
state: h2.connection.H2Connection,
|
||||
) -> None:
|
||||
self.stream_id = stream_id
|
||||
self.connection = connection
|
||||
self.state = state
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout) -> Response:
|
||||
# Send the request.
|
||||
await self.send_headers(request, timeout)
|
||||
await self.send_body(request, timeout)
|
||||
|
||||
# Receive the response.
|
||||
status_code, headers = await self.receive_response(timeout)
|
||||
content = self.body_iter(timeout)
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
http_version="HTTP/2",
|
||||
headers=headers,
|
||||
content=content,
|
||||
on_close=self.close,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async def send_headers(self, request: Request, timeout: Timeout) -> None:
|
||||
headers = [
|
||||
(b":method", request.method.encode("ascii")),
|
||||
(b":authority", request.url.authority.encode("ascii")),
|
||||
@ -135,67 +204,37 @@ class HTTP2Connection:
|
||||
|
||||
logger.trace(
|
||||
f"send_headers "
|
||||
f"stream_id={stream_id} "
|
||||
f"stream_id={self.stream_id} "
|
||||
f"method={request.method!r} "
|
||||
f"target={request.url.full_path!r} "
|
||||
f"headers={headers!r}"
|
||||
)
|
||||
self.h2_state.send_headers(stream_id, headers)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
return stream_id
|
||||
self.state.send_headers(self.stream_id, headers)
|
||||
await self.connection.send_outgoing_data(timeout)
|
||||
|
||||
async def send_request_data(
|
||||
self, stream_id: int, stream: typing.AsyncIterator[bytes], timeout: Timeout,
|
||||
) -> None:
|
||||
try:
|
||||
async for data in stream:
|
||||
await self.send_data(stream_id, data, timeout)
|
||||
await self.end_stream(stream_id, timeout)
|
||||
finally:
|
||||
# Once we've sent the request we should enable read timeouts.
|
||||
self.timeout_flags[stream_id].set_read_timeouts()
|
||||
|
||||
async def send_data(self, stream_id: int, data: bytes, timeout: Timeout) -> None:
|
||||
while data:
|
||||
# The data will be divided into frames to send based on the flow control
|
||||
# window and the maximum frame size. Because the flow control window
|
||||
# can decrease in size, even possibly to zero, this will loop until all the
|
||||
# data is sent. In http2 specification:
|
||||
# https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
flow_control = self.h2_state.local_flow_control_window(stream_id)
|
||||
chunk_size = min(
|
||||
len(data), flow_control, self.h2_state.max_outbound_frame_size
|
||||
)
|
||||
if chunk_size == 0:
|
||||
# this means that the flow control window is 0 (either for the stream
|
||||
# or the connection one), and no data can be sent until the flow control
|
||||
# window is updated.
|
||||
await self.window_update_received[stream_id].wait()
|
||||
self.window_update_received[stream_id].clear()
|
||||
else:
|
||||
async def send_body(self, request: Request, timeout: Timeout) -> None:
|
||||
logger.trace(f"send_body stream_id={self.stream_id}")
|
||||
async for data in request.stream():
|
||||
while data:
|
||||
max_flow = await self.connection.wait_for_outgoing_flow(
|
||||
self.stream_id, timeout
|
||||
)
|
||||
chunk_size = min(len(data), max_flow)
|
||||
chunk, data = data[:chunk_size], data[chunk_size:]
|
||||
self.h2_state.send_data(stream_id, chunk)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
self.state.send_data(self.stream_id, chunk)
|
||||
await self.connection.send_outgoing_data(timeout)
|
||||
|
||||
async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
|
||||
logger.trace(f"end_stream stream_id={stream_id}")
|
||||
self.h2_state.end_stream(stream_id)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
self.state.end_stream(self.stream_id)
|
||||
await self.connection.send_outgoing_data(timeout)
|
||||
|
||||
async def receive_response(
|
||||
self, stream_id: int, timeout: Timeout
|
||||
self, timeout: Timeout
|
||||
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
|
||||
"""
|
||||
Read the response status and headers from the network.
|
||||
"""
|
||||
while True:
|
||||
event = await self.receive_event(stream_id, timeout)
|
||||
# As soon as we start seeing response events, we should enable
|
||||
# read timeouts, if we haven't already.
|
||||
self.timeout_flags[stream_id].set_read_timeouts()
|
||||
event = await self.connection.wait_for_event(self.stream_id, timeout)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
@ -209,65 +248,17 @@ class HTTP2Connection:
|
||||
|
||||
return (status_code, headers)
|
||||
|
||||
async def body_iter(
|
||||
self, stream_id: int, timeout: Timeout
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]:
|
||||
while True:
|
||||
event = await self.receive_event(stream_id, timeout)
|
||||
event = await self.connection.wait_for_event(self.stream_id, timeout)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
self.h2_state.acknowledge_received_data(
|
||||
event.flow_controlled_length, stream_id
|
||||
self.state.acknowledge_received_data(
|
||||
event.flow_controlled_length, self.stream_id
|
||||
)
|
||||
await self.connection.send_outgoing_data(timeout)
|
||||
yield event.data
|
||||
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
|
||||
break
|
||||
|
||||
async def receive_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
|
||||
while not self.events[stream_id]:
|
||||
flag = self.timeout_flags[stream_id]
|
||||
data = await self.socket.read(self.READ_NUM_BYTES, timeout, flag=flag)
|
||||
events = self.h2_state.receive_data(data)
|
||||
for event in events:
|
||||
event_stream_id = getattr(event, "stream_id", 0)
|
||||
logger.trace(
|
||||
f"receive_event stream_id={event_stream_id} event={event!r}"
|
||||
)
|
||||
|
||||
if hasattr(event, "error_code"):
|
||||
raise ProtocolError(event)
|
||||
|
||||
if isinstance(event, h2.events.WindowUpdated):
|
||||
if event_stream_id == 0:
|
||||
for window_update_event in self.window_update_received.values():
|
||||
window_update_event.set()
|
||||
else:
|
||||
try:
|
||||
self.window_update_received[event_stream_id].set()
|
||||
except KeyError: # pragma: no cover
|
||||
# the window_update_received dictionary is only relevant
|
||||
# when sending data, which should never raise a KeyError
|
||||
# here.
|
||||
pass
|
||||
|
||||
if event_stream_id:
|
||||
self.events[event.stream_id].append(event)
|
||||
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.socket.write(data_to_send, timeout)
|
||||
|
||||
return self.events[stream_id].pop(0)
|
||||
|
||||
async def response_closed(self, stream_id: int) -> None:
|
||||
del self.events[stream_id]
|
||||
del self.timeout_flags[stream_id]
|
||||
del self.window_update_received[stream_id]
|
||||
|
||||
if not self.events and self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
return self.socket.is_connection_dropped()
|
||||
async def close(self) -> None:
|
||||
await self.connection.close_stream(self.stream_id)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user