Support alternate concurrency backends from client & rejig tests

This commit is contained in:
Tom Christie 2019-05-14 14:11:13 +01:00
parent 2694c48492
commit 3acc31621c
5 changed files with 141 additions and 106 deletions

View File

@ -1,4 +1,5 @@
from .client import AsyncClient, Client
from .concurrency import AsyncioBackend
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .dispatch.connection import HTTPConnection
from .dispatch.connection_pool import ConnectionPool
@ -19,7 +20,7 @@ from .exceptions import (
Timeout,
TooManyRedirects,
)
from .interfaces import BaseReader, BaseWriter, Dispatcher, Protocol
from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
from .models import URL, Headers, Origin, QueryParams, Request, Response
from .status_codes import codes

View File

@ -13,7 +13,7 @@ from .config import (
)
from .dispatch.connection_pool import ConnectionPool
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from .interfaces import Dispatcher
from .interfaces import ConcurrencyBackend, Dispatcher
from .models import (
URL,
Headers,
@ -36,9 +36,12 @@ class AsyncClient:
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
dispatch: Dispatcher = None,
backend: ConcurrencyBackend = None,
):
if dispatch is None:
dispatch = ConnectionPool(ssl=ssl, timeout=timeout, pool_limits=pool_limits)
dispatch = ConnectionPool(
ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
)
self.max_redirects = max_redirects
self.dispatch = dispatch
@ -377,6 +380,7 @@ class Client:
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
dispatch: Dispatcher = None,
backend: ConcurrencyBackend = None,
) -> None:
self._client = AsyncClient(
ssl=ssl,
@ -384,6 +388,7 @@ class Client:
pool_limits=pool_limits,
max_redirects=max_redirects,
dispatch=dispatch,
backend=backend,
)
self._loop = asyncio.new_event_loop()

View File

View File

@ -1,105 +1,36 @@
import json
import h2.config
import h2.connection
import h2.events
import pytest
from httpcore import BaseReader, BaseWriter, HTTP2Connection, Request
from httpcore import Client, Response
from .utils import MockHTTP2Backend
class MockServer(BaseReader, BaseWriter):
"""
This class exposes Reader and Writer style interfaces
"""
def __init__(self):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
self.buffer = b""
self.requests = {}
# BaseReader interface
async def read(self, n, timeout) -> bytes:
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send
# BaseWriter interface
def write_no_block(self, data: bytes) -> None:
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events:
if isinstance(event, h2.events.RequestReceived):
self.request_received(event.headers, event.stream_id)
elif isinstance(event, h2.events.DataReceived):
self.receive_data(event.data, event.stream_id)
elif isinstance(event, h2.events.StreamEnded):
self.stream_complete(event.stream_id)
async def write(self, data: bytes, timeout) -> None:
self.write_no_block(data)
async def close(self) -> None:
pass
# Server implementation
def request_received(self, headers, stream_id):
if stream_id not in self.requests:
self.requests[stream_id] = []
self.requests[stream_id].append({"headers": headers, "data": b""})
def receive_data(self, data, stream_id):
self.requests[stream_id][-1]["data"] += data
def stream_complete(self, stream_id):
request = self.requests[stream_id].pop(0)
if not self.requests[stream_id]:
del self.requests[stream_id]
request_headers = dict(request["headers"])
request_data = request["data"]
response_body = json.dumps(
{
"method": request_headers[b":method"].decode(),
"path": request_headers[b":path"].decode(),
"body": request_data.decode(),
}
).encode()
response_headers = (
(b":status", b"200"),
(b"content-length", str(len(response_body)).encode()),
)
self.conn.send_headers(stream_id, response_headers)
self.conn.send_data(stream_id, response_body, end_stream=True)
self.buffer += self.conn.data_to_send()
def app(request):
content = json.dumps({
"method": request.method,
"path": request.url.path,
"body": request.content.decode(),
}).encode()
headers = {'Content-Length': str(len(content))}
return Response(200, headers=headers, content=content)
@pytest.mark.asyncio
async def test_http2_get_request():
server = MockServer()
conn = HTTP2Connection(reader=server, writer=server)
request = Request("GET", "http://example.org")
request.prepare()
def test_http2_get_request():
backend = MockHTTP2Backend(app=app)
response = await conn.send(request)
with Client(backend=backend) as client:
response = client.get("http://example.org")
assert response.status_code == 200
assert json.loads(response.content) == {"method": "GET", "path": "/", "body": ""}
@pytest.mark.asyncio
async def test_http2_post_request():
server = MockServer()
conn = HTTP2Connection(reader=server, writer=server)
request = Request("POST", "http://example.org", data=b"<data>")
request.prepare()
def test_http2_post_request():
backend = MockHTTP2Backend(app=app)
response = await conn.send(request)
with Client(backend=backend) as client:
response = client.post("http://example.org", data=b"<data>")
assert response.status_code == 200
assert json.loads(response.content) == {
@ -109,21 +40,13 @@ async def test_http2_post_request():
}
@pytest.mark.asyncio
async def test_http2_multiple_requests():
server = MockServer()
conn = HTTP2Connection(reader=server, writer=server)
request_1 = Request("GET", "http://example.org/1")
request_2 = Request("GET", "http://example.org/2")
request_3 = Request("GET", "http://example.org/3")
def test_http2_multiple_requests():
backend = MockHTTP2Backend(app=app)
request_1.prepare()
request_2.prepare()
request_3.prepare()
response_1 = await conn.send(request_1)
response_2 = await conn.send(request_2)
response_3 = await conn.send(request_3)
with Client(backend=backend) as client:
response_1 = client.get("http://example.org/1")
response_2 = client.get("http://example.org/2")
response_3 = client.get("http://example.org/3")
assert response_1.status_code == 200
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
@ -133,5 +56,3 @@ async def test_http2_multiple_requests():
assert response_3.status_code == 200
assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
await conn.close()

108
tests/dispatch/utils.py Normal file
View File

@ -0,0 +1,108 @@
import ssl
import typing
import h2.config
import h2.connection
import h2.events
from httpcore import AsyncioBackend, BaseReader, BaseWriter, Protocol, Request, TimeoutConfig
class MockHTTP2Backend(AsyncioBackend):
def __init__(self, app):
self.app = app
async def connect(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
server = MockHTTP2Server(self.app)
return (server, server, Protocol.HTTP_2)
class MockHTTP2Server(BaseReader, BaseWriter):
"""
This class exposes Reader and Writer style interfaces.
"""
def __init__(self, app):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
self.app = app
self.buffer = b""
self.requests = {}
# BaseReader interface
async def read(self, n, timeout) -> bytes:
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send
# BaseWriter interface
def write_no_block(self, data: bytes) -> None:
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events:
if isinstance(event, h2.events.RequestReceived):
self.request_received(event.headers, event.stream_id)
elif isinstance(event, h2.events.DataReceived):
self.receive_data(event.data, event.stream_id)
elif isinstance(event, h2.events.StreamEnded):
self.stream_complete(event.stream_id)
async def write(self, data: bytes, timeout) -> None:
self.write_no_block(data)
async def close(self) -> None:
pass
# Server implementation
def request_received(self, headers, stream_id):
"""
Handler for when the initial part of the HTTP request is received.
"""
if stream_id not in self.requests:
self.requests[stream_id] = []
self.requests[stream_id].append({"headers": headers, "data": b""})
def receive_data(self, data, stream_id):
"""
Handler for when a data part of the HTTP request is received.
"""
self.requests[stream_id][-1]["data"] += data
def stream_complete(self, stream_id):
"""
Handler for when the HTTP request is completed.
"""
request = self.requests[stream_id].pop(0)
if not self.requests[stream_id]:
del self.requests[stream_id]
headers_dict = dict(request["headers"])
method = headers_dict[b":method"].decode("ascii")
url = "%s://%s%s" % (
headers_dict[b":scheme"].decode("ascii"),
headers_dict[b":authority"].decode("ascii"),
headers_dict[b":path"].decode("ascii"),
)
headers = [(k, v) for k, v in request["headers"] if not k.startswith(b":")]
data = request["data"]
# Call out to the app.
request = Request(method, url, headers=headers, data=data)
response = self.app(request)
# Write the response to the buffer.
status_code_bytes = str(int(response.status_code)).encode("ascii")
response_headers = [(b":status", status_code_bytes)] + response.headers.raw
self.conn.send_headers(stream_id, response_headers)
self.conn.send_data(stream_id, response.content, end_stream=True)
self.buffer += self.conn.data_to_send()