First pass at HTTP/2 support

This commit is contained in:
Tom Christie 2019-04-24 15:48:18 +01:00
parent 91a2a1b896
commit 53f3dc4a66
10 changed files with 322 additions and 70 deletions

View File

@ -1,4 +1,5 @@
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .connection import HTTPConnection
from .connectionpool import ConnectionPool
from .datastructures import URL, Origin, Request, Response
from .exceptions import (
@ -10,6 +11,7 @@ from .exceptions import (
StreamConsumed,
Timeout,
)
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
from .sync import SyncClient, SyncConnectionPool

View File

@ -73,7 +73,23 @@ class SSLConfig:
"invalid path: {}".format(self.verify)
)
context = ssl.create_default_context()
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_COMPRESSION
# RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST
# support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the
# blacklist defined in this section allows only the AES GCM and ChaCha20
# cipher suites with ephemeral key negotiation.
context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20")
if ssl.HAS_ALPN:
context.set_alpn_protocols(["h2", "http/1.1"])
if ssl.HAS_NPN:
context.set_npn_protocols(["h2", "http/1.1"])
if os.path.isfile(ca_bundle_path):
context.load_verify_locations(cafile=ca_bundle_path)
elif os.path.isdir(ca_bundle_path):

106
httpcore/connection.py Normal file
View File

@ -0,0 +1,106 @@
import asyncio
import typing
import h2.connection
import h11
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
from .datastructures import Client, Origin, Request, Response
from .exceptions import ConnectTimeout
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
class HTTPConnection(Client):
def __init__(
self,
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
self.on_release = on_release
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
async def send(
self,
request: Request,
*,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
if self.h11_connection is None and self.h2_connection is None:
if ssl is None:
ssl = self.ssl
if timeout is None:
timeout = self.timeout
reader, writer, protocol = await self.connect(ssl, timeout)
if protocol == "h2":
self.h2_connection = HTTP2Connection(
reader,
writer,
origin=self.origin,
timeout=self.timeout,
on_release=self.on_release,
)
else:
self.h11_connection = HTTP11Connection(
reader,
writer,
origin=self.origin,
timeout=self.timeout,
on_release=self.on_release,
)
if self.h2_connection is not None:
response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout)
else:
assert self.h11_connection is not None
response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout)
return response
async def close(self) -> None:
if self.h2_connection is not None:
await self.h2_connection.close()
else:
assert self.h11_connection is not None
await self.h11_connection.close()
@property
def is_closed(self) -> bool:
if self.h2_connection is not None:
return self.h2_connection.is_closed
else:
assert self.h11_connection is not None
return self.h11_connection.is_closed
async def connect(
self, ssl: SSLConfig, timeout: TimeoutConfig
) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]:
hostname = self.origin.hostname
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
try:
reader, writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
ssl_object = writer.get_extra_info("ssl_object")
if ssl_object is None:
protocol = "http/1.1"
else:
protocol = ssl_object.selected_alpn_protocol()
if protocol is None:
protocol = ssl_object.selected_npn_protocol()
return (reader, writer, protocol)

View File

@ -10,9 +10,9 @@ from .config import (
SSLConfig,
TimeoutConfig,
)
from .connection import HTTPConnection
from .datastructures import Client, Origin, Request, Response
from .exceptions import PoolTimeout
from .http11 import HTTP11Connection
class ConnectionPool(Client):
@ -31,7 +31,7 @@ class ConnectionPool(Client):
self.num_keepalive_connections = 0
self._keepalive_connections = (
{}
) # type: typing.Dict[Origin, typing.List[HTTP11Connection]]
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
self._max_connections = ConnectionSemaphore(
max_connections=self.limits.hard_limit
)
@ -53,7 +53,7 @@ class ConnectionPool(Client):
async def acquire_connection(
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
) -> HTTP11Connection:
) -> HTTPConnection:
try:
connection = self._keepalive_connections[origin].pop()
if not self._keepalive_connections[origin]:
@ -71,7 +71,7 @@ class ConnectionPool(Client):
await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
connection = HTTP11Connection(
connection = HTTPConnection(
origin,
ssl=self.ssl,
timeout=self.timeout,
@ -81,7 +81,7 @@ class ConnectionPool(Client):
return connection
async def release_connection(self, connection: HTTP11Connection) -> None:
async def release_connection(self, connection: HTTPConnection) -> None:
if connection.is_closed:
self._max_connections.release()
self.num_active_connections -= 1

View File

@ -52,7 +52,7 @@ class URL:
return port
@property
def target(self) -> str:
def full_path(self) -> str:
path = self.path or "/"
query = self.query
if query:
@ -138,10 +138,11 @@ class Request:
return headers
async def stream(self) -> typing.AsyncIterator[bytes]:
assert self.is_streaming
async for part in self.body_aiter:
yield part
if self.is_streaming:
async for part in self.body_aiter:
yield part
elif self.body:
yield self.body
class Response:
@ -150,6 +151,7 @@ class Response:
status_code: int,
*,
reason: typing.Optional[str] = None,
protocol: typing.Optional[str] = None,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
@ -162,6 +164,7 @@ class Response:
self.reason = ""
else:
self.reason = reason
self.protocol = protocol
self.headers = list(headers)
self.on_close = on_close
self.is_closed = False

View File

@ -20,55 +20,43 @@ H11Event = typing.Union[
class HTTP11Connection(Client):
def __init__(
self,
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
origin: Origin,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.origin = origin
self.reader = reader
self.writer = writer
self.timeout = timeout
self.on_release = on_release
self._reader = None
self._writer = None
self._h11_state = h11.Connection(our_role=h11.CLIENT)
self.h11_state = h11.Connection(our_role=h11.CLIENT)
@property
def is_closed(self) -> bool:
return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
async def send(
self,
request: Request,
*,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None
) -> Response:
assert request.url.origin == self.origin
if ssl is None:
ssl = self.ssl
if timeout is None:
timeout = self.timeout
# Make the connection
if self._reader is None:
await self._connect(ssl, timeout)
#  Start sending the request.
method = request.method.encode()
target = request.url.target
target = request.url.full_path
headers = request.headers
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event)
# Send the request body.
if request.is_streaming:
async for data in request.stream():
event = h11.Data(data=data)
await self._send_event(event)
elif request.body:
event = h11.Data(data=request.body)
async for data in request.stream():
event = h11.Data(data=data)
await self._send_event(event)
# Finalize sending the request.
@ -79,32 +67,22 @@ class HTTP11Connection(Client):
event = await self._receive_event(timeout)
if isinstance(event, h11.InformationalResponse):
event = await self._receive_event(timeout)
assert isinstance(event, h11.Response)
reason = event.reason.decode("latin1")
status_code = event.status_code
headers = event.headers
body = self._body_iter(timeout)
return Response(
status_code=status_code,
reason=reason,
protocol="HTTP/1.1",
headers=headers,
body=body,
on_close=self._release,
)
async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
hostname = self.origin.hostname
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
try:
self._reader, self._writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
event = await self._receive_event(timeout)
while isinstance(event, h11.Data):
@ -113,36 +91,30 @@ class HTTP11Connection(Client):
assert isinstance(event, h11.EndOfMessage)
async def _send_event(self, event: H11Event) -> None:
assert self._writer is not None
data = self._h11_state.send(event)
self._writer.write(data)
data = self.h11_state.send(event)
self.writer.write(data)
async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
assert self._reader is not None
event = self._h11_state.next_event()
event = self.h11_state.next_event()
while event is h11.NEED_DATA:
try:
data = await asyncio.wait_for(
self._reader.read(2048), timeout.read_timeout
self.reader.read(2048), timeout.read_timeout
)
except asyncio.TimeoutError:
raise ReadTimeout()
self._h11_state.receive_data(data)
event = self._h11_state.next_event()
self.h11_state.receive_data(data)
event = self.h11_state.next_event()
return event
async def _release(self) -> None:
assert self._writer is not None
if (
self._h11_state.our_state is h11.DONE
and self._h11_state.their_state is h11.DONE
self.h11_state.our_state is h11.DONE
and self.h11_state.their_state is h11.DONE
):
self._h11_state.start_next_cycle()
self.h11_state.start_next_cycle()
else:
await self.close()
@ -153,11 +125,11 @@ class HTTP11Connection(Client):
event = h11.ConnectionClosed()
try:
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
self._h11_state.send(event)
self.h11_state.send(event)
except h11.ProtocolError:
# If we're in some other state then it's a premature close,
# and we'll end up in h11.ERROR.
pass
if self._writer is not None:
self._writer.close()
if self.writer is not None:
self.writer.close()

152
httpcore/http2.py Normal file
View File

@ -0,0 +1,152 @@
import asyncio
import typing
import h2.connection
import h2.events
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
from .datastructures import Client, Origin, Request, Response
from .exceptions import ConnectTimeout, ReadTimeout
class HTTP2Connection(Client):
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
origin: Origin,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
):
self.origin = origin
self.reader = reader
self.writer = writer
self.timeout = timeout
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.events = [] # type: typing.List[h2.events.Event]
@property
def is_closed(self) -> bool:
return False
async def send(
self,
request: Request,
*,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None
) -> Response:
if timeout is None:
timeout = self.timeout
#  Start sending the request.
await self._initiate_connection()
await self._send_headers(request)
# Send the request body.
if request.body:
await self._send_data(request.body)
# Finalize sending the request.
await self._end_stream()
# Start getting the response.
while True:
event = await self._receive_event(timeout)
if isinstance(event, h2.events.ResponseReceived):
break
status_code = 200
headers = []
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode())
elif not k.startswith(b":"):
headers.append((k, v))
body = self._body_iter(timeout)
return Response(
status_code=status_code,
protocol="HTTP/2",
headers=headers,
body=body,
on_close=self._release,
)
async def _initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
async def _send_headers(self, request: Request) -> None:
headers = [
(b":method", request.method.encode()),
(b":authority", request.url.hostname.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
] + request.headers
self.h2_state.send_headers(1, headers)
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
async def _send_data(self, data: bytes) -> None:
self.h2_state.send_data(1, data)
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
async def _end_stream(self) -> None:
self.h2_state.end_stream(1)
data_to_send = self.h2_state.data_to_send()
self.writer.write(data_to_send)
async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
while True:
event = await self._receive_event(timeout)
if isinstance(event, h2.events.DataReceived):
yield event.data
elif isinstance(event, h2.events.StreamEnded):
break
async def _receive_event(self, timeout: TimeoutConfig) -> h2.events.Event:
while not self.events:
try:
data = await asyncio.wait_for(
self.reader.read(2048), timeout.read_timeout
)
except asyncio.TimeoutError:
raise ReadTimeout()
events = self.h2_state.receive_data(data)
self.events.extend(events)
data_to_send = self.h2_state.data_to_send()
if data_to_send:
self.writer.write(data_to_send)
return self.events.pop(0)
async def _release(self) -> None:
# if (
# self.h11_state.our_state is h11.DONE
# and self.h11_state.their_state is h11.DONE
# ):
# self.h11_state.start_next_cycle()
# else:
# await self.close()
if self.on_release is not None:
await self.on_release(self)
async def close(self) -> None:
# event = h11.ConnectionClosed()
# try:
# # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
# self.h11_state.send(event)
# except h11.ProtocolError:
# # If we're in some other state then it's a premature close,
# # and we'll end up in h11.ERROR.
# pass
if self.writer is not None:
self.writer.close()

View File

@ -1,5 +1,6 @@
certifi
h11
h2
# Optional
brotlipy

View File

@ -5,7 +5,7 @@ import httpcore
@pytest.mark.asyncio
async def test_get(server):
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
response = await http.request("GET", "http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.body == b"Hello, world!"
@ -13,7 +13,7 @@ async def test_get(server):
@pytest.mark.asyncio
async def test_post(server):
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
response = await http.request(
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
)

View File

@ -73,12 +73,12 @@ def test_url():
request = httpcore.Request("GET", "http://example.org")
assert request.url.scheme == "http"
assert request.url.port == 80
assert request.url.target == "/"
assert request.url.full_path == "/"
request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
assert request.url.scheme == "https"
assert request.url.port == 443
assert request.url.target == "/abc?foo=bar"
assert request.url.full_path == "/abc?foo=bar"
def test_invalid_urls():