First pass at HTTP/2 support
This commit is contained in:
parent
91a2a1b896
commit
53f3dc4a66
@ -1,4 +1,5 @@
|
||||
from .config import PoolLimits, SSLConfig, TimeoutConfig
|
||||
from .connection import HTTPConnection
|
||||
from .connectionpool import ConnectionPool
|
||||
from .datastructures import URL, Origin, Request, Response
|
||||
from .exceptions import (
|
||||
@ -10,6 +11,7 @@ from .exceptions import (
|
||||
StreamConsumed,
|
||||
Timeout,
|
||||
)
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
from .sync import SyncClient, SyncConnectionPool
|
||||
|
||||
|
||||
@ -73,7 +73,23 @@ class SSLConfig:
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
|
||||
context = ssl.create_default_context()
|
||||
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
|
||||
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
|
||||
# RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST
|
||||
# support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the
|
||||
# blacklist defined in this section allows only the AES GCM and ChaCha20
|
||||
# cipher suites with ephemeral key negotiation.
|
||||
context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20")
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
context.set_alpn_protocols(["h2", "http/1.1"])
|
||||
if ssl.HAS_NPN:
|
||||
context.set_npn_protocols(["h2", "http/1.1"])
|
||||
|
||||
if os.path.isfile(ca_bundle_path):
|
||||
context.load_verify_locations(cafile=ca_bundle_path)
|
||||
elif os.path.isdir(ca_bundle_path):
|
||||
|
||||
106
httpcore/connection.py
Normal file
106
httpcore/connection.py
Normal file
@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
import h11
|
||||
|
||||
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
|
||||
from .datastructures import Client, Origin, Request, Response
|
||||
from .exceptions import ConnectTimeout
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
|
||||
class HTTPConnection(Client):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = ssl
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
|
||||
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
) -> Response:
|
||||
if self.h11_connection is None and self.h2_connection is None:
|
||||
if ssl is None:
|
||||
ssl = self.ssl
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
reader, writer, protocol = await self.connect(ssl, timeout)
|
||||
if protocol == "h2":
|
||||
self.h2_connection = HTTP2Connection(
|
||||
reader,
|
||||
writer,
|
||||
origin=self.origin,
|
||||
timeout=self.timeout,
|
||||
on_release=self.on_release,
|
||||
)
|
||||
else:
|
||||
self.h11_connection = HTTP11Connection(
|
||||
reader,
|
||||
writer,
|
||||
origin=self.origin,
|
||||
timeout=self.timeout,
|
||||
on_release=self.on_release,
|
||||
)
|
||||
|
||||
if self.h2_connection is not None:
|
||||
response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout)
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout)
|
||||
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.h2_connection is not None:
|
||||
await self.h2_connection.close()
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
await self.h11_connection.close()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
if self.h2_connection is not None:
|
||||
return self.h2_connection.is_closed
|
||||
else:
|
||||
assert self.h11_connection is not None
|
||||
return self.h11_connection.is_closed
|
||||
|
||||
async def connect(
|
||||
self, ssl: SSLConfig, timeout: TimeoutConfig
|
||||
) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]:
|
||||
hostname = self.origin.hostname
|
||||
port = self.origin.port
|
||||
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
|
||||
|
||||
try:
|
||||
reader, writer = await asyncio.wait_for( # type: ignore
|
||||
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
||||
timeout.connect_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectTimeout()
|
||||
|
||||
ssl_object = writer.get_extra_info("ssl_object")
|
||||
if ssl_object is None:
|
||||
protocol = "http/1.1"
|
||||
else:
|
||||
protocol = ssl_object.selected_alpn_protocol()
|
||||
if protocol is None:
|
||||
protocol = ssl_object.selected_npn_protocol()
|
||||
|
||||
return (reader, writer, protocol)
|
||||
@ -10,9 +10,9 @@ from .config import (
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from .connection import HTTPConnection
|
||||
from .datastructures import Client, Origin, Request, Response
|
||||
from .exceptions import PoolTimeout
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
|
||||
class ConnectionPool(Client):
|
||||
@ -31,7 +31,7 @@ class ConnectionPool(Client):
|
||||
self.num_keepalive_connections = 0
|
||||
self._keepalive_connections = (
|
||||
{}
|
||||
) # type: typing.Dict[Origin, typing.List[HTTP11Connection]]
|
||||
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
self._max_connections = ConnectionSemaphore(
|
||||
max_connections=self.limits.hard_limit
|
||||
)
|
||||
@ -53,7 +53,7 @@ class ConnectionPool(Client):
|
||||
|
||||
async def acquire_connection(
|
||||
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> HTTP11Connection:
|
||||
) -> HTTPConnection:
|
||||
try:
|
||||
connection = self._keepalive_connections[origin].pop()
|
||||
if not self._keepalive_connections[origin]:
|
||||
@ -71,7 +71,7 @@ class ConnectionPool(Client):
|
||||
await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise PoolTimeout()
|
||||
connection = HTTP11Connection(
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
timeout=self.timeout,
|
||||
@ -81,7 +81,7 @@ class ConnectionPool(Client):
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, connection: HTTP11Connection) -> None:
|
||||
async def release_connection(self, connection: HTTPConnection) -> None:
|
||||
if connection.is_closed:
|
||||
self._max_connections.release()
|
||||
self.num_active_connections -= 1
|
||||
|
||||
@ -52,7 +52,7 @@ class URL:
|
||||
return port
|
||||
|
||||
@property
|
||||
def target(self) -> str:
|
||||
def full_path(self) -> str:
|
||||
path = self.path or "/"
|
||||
query = self.query
|
||||
if query:
|
||||
@ -138,10 +138,11 @@ class Request:
|
||||
return headers
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
assert self.is_streaming
|
||||
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
if self.is_streaming:
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
elif self.body:
|
||||
yield self.body
|
||||
|
||||
|
||||
class Response:
|
||||
@ -150,6 +151,7 @@ class Response:
|
||||
status_code: int,
|
||||
*,
|
||||
reason: typing.Optional[str] = None,
|
||||
protocol: typing.Optional[str] = None,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
on_close: typing.Callable = None,
|
||||
@ -162,6 +164,7 @@ class Response:
|
||||
self.reason = ""
|
||||
else:
|
||||
self.reason = reason
|
||||
self.protocol = protocol
|
||||
self.headers = list(headers)
|
||||
self.on_close = on_close
|
||||
self.is_closed = False
|
||||
|
||||
@ -20,55 +20,43 @@ H11Event = typing.Union[
|
||||
class HTTP11Connection(Client):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
origin: Origin,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = ssl
|
||||
self.origin = origin
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> Response:
|
||||
assert request.url.origin == self.origin
|
||||
|
||||
if ssl is None:
|
||||
ssl = self.ssl
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
# Make the connection
|
||||
if self._reader is None:
|
||||
await self._connect(ssl, timeout)
|
||||
|
||||
# Start sending the request.
|
||||
method = request.method.encode()
|
||||
target = request.url.target
|
||||
target = request.url.full_path
|
||||
headers = request.headers
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
await self._send_event(event)
|
||||
|
||||
# Send the request body.
|
||||
if request.is_streaming:
|
||||
async for data in request.stream():
|
||||
event = h11.Data(data=data)
|
||||
await self._send_event(event)
|
||||
elif request.body:
|
||||
event = h11.Data(data=request.body)
|
||||
async for data in request.stream():
|
||||
event = h11.Data(data=data)
|
||||
await self._send_event(event)
|
||||
|
||||
# Finalize sending the request.
|
||||
@ -79,32 +67,22 @@ class HTTP11Connection(Client):
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h11.InformationalResponse):
|
||||
event = await self._receive_event(timeout)
|
||||
|
||||
assert isinstance(event, h11.Response)
|
||||
reason = event.reason.decode("latin1")
|
||||
status_code = event.status_code
|
||||
headers = event.headers
|
||||
body = self._body_iter(timeout)
|
||||
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
reason=reason,
|
||||
protocol="HTTP/1.1",
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=self._release,
|
||||
)
|
||||
|
||||
async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
|
||||
hostname = self.origin.hostname
|
||||
port = self.origin.port
|
||||
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
|
||||
|
||||
try:
|
||||
self._reader, self._writer = await asyncio.wait_for( # type: ignore
|
||||
asyncio.open_connection(hostname, port, ssl=ssl_context),
|
||||
timeout.connect_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ConnectTimeout()
|
||||
|
||||
async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
|
||||
event = await self._receive_event(timeout)
|
||||
while isinstance(event, h11.Data):
|
||||
@ -113,36 +91,30 @@ class HTTP11Connection(Client):
|
||||
assert isinstance(event, h11.EndOfMessage)
|
||||
|
||||
async def _send_event(self, event: H11Event) -> None:
|
||||
assert self._writer is not None
|
||||
|
||||
data = self._h11_state.send(event)
|
||||
self._writer.write(data)
|
||||
data = self.h11_state.send(event)
|
||||
self.writer.write(data)
|
||||
|
||||
async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
|
||||
assert self._reader is not None
|
||||
|
||||
event = self._h11_state.next_event()
|
||||
event = self.h11_state.next_event()
|
||||
|
||||
while event is h11.NEED_DATA:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self._reader.read(2048), timeout.read_timeout
|
||||
self.reader.read(2048), timeout.read_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout()
|
||||
self._h11_state.receive_data(data)
|
||||
event = self._h11_state.next_event()
|
||||
self.h11_state.receive_data(data)
|
||||
event = self.h11_state.next_event()
|
||||
|
||||
return event
|
||||
|
||||
async def _release(self) -> None:
|
||||
assert self._writer is not None
|
||||
|
||||
if (
|
||||
self._h11_state.our_state is h11.DONE
|
||||
and self._h11_state.their_state is h11.DONE
|
||||
self.h11_state.our_state is h11.DONE
|
||||
and self.h11_state.their_state is h11.DONE
|
||||
):
|
||||
self._h11_state.start_next_cycle()
|
||||
self.h11_state.start_next_cycle()
|
||||
else:
|
||||
await self.close()
|
||||
|
||||
@ -153,11 +125,11 @@ class HTTP11Connection(Client):
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
|
||||
self._h11_state.send(event)
|
||||
self.h11_state.send(event)
|
||||
except h11.ProtocolError:
|
||||
# If we're in some other state then it's a premature close,
|
||||
# and we'll end up in h11.ERROR.
|
||||
pass
|
||||
|
||||
if self._writer is not None:
|
||||
self._writer.close()
|
||||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
|
||||
152
httpcore/http2.py
Normal file
152
httpcore/http2.py
Normal file
@ -0,0 +1,152 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
|
||||
from .datastructures import Client, Origin, Request, Response
|
||||
from .exceptions import ConnectTimeout, ReadTimeout
|
||||
|
||||
|
||||
class HTTP2Connection(Client):
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
origin: Origin,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.origin = origin
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self.h2_state = h2.connection.H2Connection()
|
||||
self.events = [] # type: typing.List[h2.events.Event]
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return False
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> Response:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
# Start sending the request.
|
||||
await self._initiate_connection()
|
||||
await self._send_headers(request)
|
||||
|
||||
# Send the request body.
|
||||
if request.body:
|
||||
await self._send_data(request.body)
|
||||
|
||||
# Finalize sending the request.
|
||||
await self._end_stream()
|
||||
|
||||
# Start getting the response.
|
||||
while True:
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
status_code = 200
|
||||
headers = []
|
||||
for k, v in event.headers:
|
||||
if k == b":status":
|
||||
status_code = int(v.decode())
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
|
||||
body = self._body_iter(timeout)
|
||||
return Response(
|
||||
status_code=status_code,
|
||||
protocol="HTTP/2",
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=self._release,
|
||||
)
|
||||
|
||||
async def _initiate_connection(self) -> None:
|
||||
self.h2_state.initiate_connection()
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.writer.write(data_to_send)
|
||||
|
||||
async def _send_headers(self, request: Request) -> None:
|
||||
headers = [
|
||||
(b":method", request.method.encode()),
|
||||
(b":authority", request.url.hostname.encode()),
|
||||
(b":scheme", request.url.scheme.encode()),
|
||||
(b":path", request.url.full_path.encode()),
|
||||
] + request.headers
|
||||
self.h2_state.send_headers(1, headers)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.writer.write(data_to_send)
|
||||
|
||||
async def _send_data(self, data: bytes) -> None:
|
||||
self.h2_state.send_data(1, data)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.writer.write(data_to_send)
|
||||
|
||||
async def _end_stream(self) -> None:
|
||||
self.h2_state.end_stream(1)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
self.writer.write(data_to_send)
|
||||
|
||||
async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
|
||||
while True:
|
||||
event = await self._receive_event(timeout)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
yield event.data
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
break
|
||||
|
||||
async def _receive_event(self, timeout: TimeoutConfig) -> h2.events.Event:
|
||||
while not self.events:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self.reader.read(2048), timeout.read_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout()
|
||||
|
||||
events = self.h2_state.receive_data(data)
|
||||
self.events.extend(events)
|
||||
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
if data_to_send:
|
||||
self.writer.write(data_to_send)
|
||||
|
||||
return self.events.pop(0)
|
||||
|
||||
async def _release(self) -> None:
|
||||
# if (
|
||||
# self.h11_state.our_state is h11.DONE
|
||||
# and self.h11_state.their_state is h11.DONE
|
||||
# ):
|
||||
# self.h11_state.start_next_cycle()
|
||||
# else:
|
||||
# await self.close()
|
||||
|
||||
if self.on_release is not None:
|
||||
await self.on_release(self)
|
||||
|
||||
async def close(self) -> None:
|
||||
# event = h11.ConnectionClosed()
|
||||
# try:
|
||||
# # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
|
||||
# self.h11_state.send(event)
|
||||
# except h11.ProtocolError:
|
||||
# # If we're in some other state then it's a premature close,
|
||||
# # and we'll end up in h11.ERROR.
|
||||
# pass
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
@ -1,5 +1,6 @@
|
||||
certifi
|
||||
h11
|
||||
h2
|
||||
|
||||
# Optional
|
||||
brotlipy
|
||||
|
||||
@ -5,7 +5,7 @@ import httpcore
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(server):
|
||||
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
|
||||
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"Hello, world!"
|
||||
@ -13,7 +13,7 @@ async def test_get(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(server):
|
||||
http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
|
||||
http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
|
||||
response = await http.request(
|
||||
"POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
|
||||
)
|
||||
|
||||
@ -73,12 +73,12 @@ def test_url():
|
||||
request = httpcore.Request("GET", "http://example.org")
|
||||
assert request.url.scheme == "http"
|
||||
assert request.url.port == 80
|
||||
assert request.url.target == "/"
|
||||
assert request.url.full_path == "/"
|
||||
|
||||
request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
|
||||
assert request.url.scheme == "https"
|
||||
assert request.url.port == 443
|
||||
assert request.url.target == "/abc?foo=bar"
|
||||
assert request.url.full_path == "/abc?foo=bar"
|
||||
|
||||
|
||||
def test_invalid_urls():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user