Fix bug in calculating maximum frame size (#153)
This commit is contained in:
parent
3065dd6b6c
commit
ab9cd0ea11
@ -88,6 +88,9 @@ class BaseEvent:
|
||||
def is_set(self) -> bool:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def wait(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import typing
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
|
||||
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag, BaseEvent
|
||||
from ..config import TimeoutConfig, TimeoutTypes
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
from ..utils import get_logger
|
||||
@ -28,6 +28,7 @@ class HTTP2Connection:
|
||||
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
|
||||
self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag]
|
||||
self.initialized = False
|
||||
self.window_update_received = {} # type: typing.Dict[int, BaseEvent]
|
||||
|
||||
async def send(
|
||||
self, request: AsyncRequest, timeout: TimeoutTypes = None
|
||||
@ -42,6 +43,7 @@ class HTTP2Connection:
|
||||
|
||||
self.events[stream_id] = []
|
||||
self.timeout_flags[stream_id] = TimeoutFlag()
|
||||
self.window_update_received[stream_id] = self.backend.create_event()
|
||||
|
||||
task, args = self.send_request_data, [stream_id, request.stream(), timeout]
|
||||
async with self.backend.background_manager(task, *args):
|
||||
@ -108,18 +110,27 @@ class HTTP2Connection:
|
||||
async def send_data(
|
||||
self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
|
||||
) -> None:
|
||||
flow_control = self.h2_state.local_flow_control_window(stream_id)
|
||||
chunk_size = min(len(data), flow_control)
|
||||
for idx in range(0, len(data), chunk_size):
|
||||
chunk = data[idx : idx + chunk_size]
|
||||
|
||||
logger.debug(
|
||||
f"send_data stream_id={stream_id} data=Data(<{len(chunk)} bytes>)"
|
||||
while data:
|
||||
# The data will be divided into frames to send based on the flow control
|
||||
# window and the maximum frame size. Because the flow control window
|
||||
# can decrease in size, even possibly to zero, this will loop until all the
|
||||
# data is sent. In http2 specification:
|
||||
# https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
flow_control = self.h2_state.local_flow_control_window(stream_id)
|
||||
chunk_size = min(
|
||||
len(data), flow_control, self.h2_state.max_outbound_frame_size
|
||||
)
|
||||
|
||||
self.h2_state.send_data(stream_id, chunk)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.stream.write(data_to_send, timeout)
|
||||
if chunk_size == 0:
|
||||
# this means that the flow control window is 0 (either for the stream
|
||||
# or the connection one), and no data can be sent until the flow control
|
||||
# window is updated.
|
||||
await self.window_update_received[stream_id].wait()
|
||||
self.window_update_received[stream_id].clear()
|
||||
else:
|
||||
chunk, data = data[:chunk_size], data[chunk_size:]
|
||||
self.h2_state.send_data(stream_id, chunk)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.stream.write(data_to_send, timeout)
|
||||
|
||||
async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
|
||||
logger.debug(f"end_stream stream_id={stream_id}")
|
||||
@ -148,7 +159,8 @@ class HTTP2Connection:
|
||||
status_code = int(v.decode("ascii", errors="ignore"))
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
return status_code, headers
|
||||
|
||||
return (status_code, headers)
|
||||
|
||||
async def body_iter(
|
||||
self, stream_id: int, timeout: TimeoutConfig = None
|
||||
@ -156,7 +168,9 @@ class HTTP2Connection:
|
||||
while True:
|
||||
event = await self.receive_event(stream_id, timeout)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
self.h2_state.acknowledge_received_data(len(event.data), stream_id)
|
||||
self.h2_state.acknowledge_received_data(
|
||||
event.flow_controlled_length, stream_id
|
||||
)
|
||||
yield event.data
|
||||
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
|
||||
break
|
||||
@ -173,6 +187,19 @@ class HTTP2Connection:
|
||||
logger.debug(
|
||||
f"receive_event stream_id={event_stream_id} event={event!r}"
|
||||
)
|
||||
if isinstance(event, h2.events.WindowUpdated):
|
||||
if event_stream_id == 0:
|
||||
for window_update_event in self.window_update_received.values():
|
||||
window_update_event.set()
|
||||
else:
|
||||
try:
|
||||
self.window_update_received[event_stream_id].set()
|
||||
except KeyError:
|
||||
# the window_update_received dictionary is only relevant
|
||||
# when sending data, which should never raise a KeyError
|
||||
# here.
|
||||
pass
|
||||
|
||||
if event_stream_id:
|
||||
self.events[event.stream_id].append(event)
|
||||
|
||||
@ -184,6 +211,7 @@ class HTTP2Connection:
|
||||
async def response_closed(self, stream_id: int) -> None:
|
||||
del self.events[stream_id]
|
||||
del self.timeout_flags[stream_id]
|
||||
del self.window_update_received[stream_id]
|
||||
|
||||
if not self.events and self.on_release is not None:
|
||||
await self.on_release()
|
||||
|
||||
@ -65,6 +65,34 @@ async def test_async_http2_post_request(backend):
|
||||
}
|
||||
|
||||
|
||||
def test_http2_large_post_request():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
|
||||
data = b"a" * 100000
|
||||
with Client(backend=backend) as client:
|
||||
response = client.post("http://example.org", data=data)
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"body": data.decode(),
|
||||
}
|
||||
|
||||
|
||||
async def test_async_http2_large_post_request(backend):
|
||||
backend = MockHTTP2Backend(app=app, backend=backend)
|
||||
|
||||
data = b"a" * 100000
|
||||
async with AsyncClient(backend=backend) as client:
|
||||
response = await client.post("http://example.org", data=data)
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"body": data.decode(),
|
||||
}
|
||||
|
||||
|
||||
def test_http2_multiple_requests():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
|
||||
|
||||
@ -39,6 +39,8 @@ class MockHTTP2Server(BaseStream):
|
||||
self.buffer = b""
|
||||
self.requests = {}
|
||||
self.close_connection = False
|
||||
self.return_data = {}
|
||||
self.returning = {}
|
||||
|
||||
# Stream interface
|
||||
|
||||
@ -58,8 +60,27 @@ class MockHTTP2Server(BaseStream):
|
||||
self.request_received(event.headers, event.stream_id)
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
self.receive_data(event.data, event.stream_id)
|
||||
# This should send an UPDATE_WINDOW for both the stream and the
|
||||
# connection increasing it by the amount
|
||||
# consumed keeping the flow control window constant
|
||||
flow_control_consumed = event.flow_controlled_length
|
||||
if flow_control_consumed > 0:
|
||||
self.conn.increment_flow_control_window(flow_control_consumed)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
self.conn.increment_flow_control_window(
|
||||
flow_control_consumed, event.stream_id
|
||||
)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
self.stream_complete(event.stream_id)
|
||||
elif isinstance(event, h2.events.WindowUpdated):
|
||||
if event.stream_id == 0:
|
||||
for key, value in self.returning.items():
|
||||
if value:
|
||||
self.send_return_data(key)
|
||||
# This will throw an error if the event is for a not-yet created stream
|
||||
elif self.returning[event.stream_id]:
|
||||
self.send_return_data(event.stream_id)
|
||||
|
||||
async def write(self, data: bytes, timeout) -> None:
|
||||
self.write_no_block(data)
|
||||
@ -114,5 +135,28 @@ class MockHTTP2Server(BaseStream):
|
||||
response_headers = [(b":status", status_code_bytes)] + response.headers.raw
|
||||
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.conn.send_data(stream_id, response.content, end_stream=True)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
self.return_data[stream_id] = response.content
|
||||
self.returning[stream_id] = True
|
||||
self.send_return_data(stream_id)
|
||||
|
||||
def send_return_data(self, stream_id):
|
||||
while self.return_data[stream_id]:
|
||||
flow_control = self.conn.local_flow_control_window(stream_id)
|
||||
chunk_size = min(
|
||||
len(self.return_data[stream_id]),
|
||||
flow_control,
|
||||
self.conn.max_outbound_frame_size,
|
||||
)
|
||||
if chunk_size == 0:
|
||||
return
|
||||
else:
|
||||
chunk, self.return_data[stream_id] = (
|
||||
self.return_data[stream_id][:chunk_size],
|
||||
self.return_data[stream_id][chunk_size:],
|
||||
)
|
||||
self.conn.send_data(stream_id, chunk)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
self.returning[stream_id] = False
|
||||
self.conn.end_stream(stream_id)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user