Support alternate concurrency backends from client & rejig tests
This commit is contained in:
parent
2694c48492
commit
3acc31621c
@ -1,4 +1,5 @@
|
||||
from .client import AsyncClient, Client
|
||||
from .concurrency import AsyncioBackend
|
||||
from .config import PoolLimits, SSLConfig, TimeoutConfig
|
||||
from .dispatch.connection import HTTPConnection
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
@ -19,7 +20,7 @@ from .exceptions import (
|
||||
Timeout,
|
||||
TooManyRedirects,
|
||||
)
|
||||
from .interfaces import BaseReader, BaseWriter, Dispatcher, Protocol
|
||||
from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
|
||||
from .models import URL, Headers, Origin, QueryParams, Request, Response
|
||||
from .status_codes import codes
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from .config import (
|
||||
)
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from .interfaces import Dispatcher
|
||||
from .interfaces import ConcurrencyBackend, Dispatcher
|
||||
from .models import (
|
||||
URL,
|
||||
Headers,
|
||||
@ -36,9 +36,12 @@ class AsyncClient:
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
dispatch: Dispatcher = None,
|
||||
backend: ConcurrencyBackend = None,
|
||||
):
|
||||
if dispatch is None:
|
||||
dispatch = ConnectionPool(ssl=ssl, timeout=timeout, pool_limits=pool_limits)
|
||||
dispatch = ConnectionPool(
|
||||
ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
|
||||
)
|
||||
|
||||
self.max_redirects = max_redirects
|
||||
self.dispatch = dispatch
|
||||
@ -377,6 +380,7 @@ class Client:
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
dispatch: Dispatcher = None,
|
||||
backend: ConcurrencyBackend = None,
|
||||
) -> None:
|
||||
self._client = AsyncClient(
|
||||
ssl=ssl,
|
||||
@ -384,6 +388,7 @@ class Client:
|
||||
pool_limits=pool_limits,
|
||||
max_redirects=max_redirects,
|
||||
dispatch=dispatch,
|
||||
backend=backend,
|
||||
)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
|
||||
0
tests/dispatch/__init__.py
Normal file
0
tests/dispatch/__init__.py
Normal file
@ -1,105 +1,36 @@
|
||||
import json
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import pytest
|
||||
|
||||
from httpcore import BaseReader, BaseWriter, HTTP2Connection, Request
|
||||
from httpcore import Client, Response
|
||||
from .utils import MockHTTP2Backend
|
||||
|
||||
|
||||
class MockServer(BaseReader, BaseWriter):
|
||||
"""
|
||||
This class exposes Reader and Writer style interfaces
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
config = h2.config.H2Configuration(client_side=False)
|
||||
self.conn = h2.connection.H2Connection(config=config)
|
||||
self.buffer = b""
|
||||
self.requests = {}
|
||||
|
||||
# BaseReader interface
|
||||
|
||||
async def read(self, n, timeout) -> bytes:
|
||||
send, self.buffer = self.buffer[:n], self.buffer[n:]
|
||||
return send
|
||||
|
||||
# BaseWriter interface
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
events = self.conn.receive_data(data)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
self.request_received(event.headers, event.stream_id)
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
self.receive_data(event.data, event.stream_id)
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
self.stream_complete(event.stream_id)
|
||||
|
||||
async def write(self, data: bytes, timeout) -> None:
|
||||
self.write_no_block(data)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
# Server implementation
|
||||
|
||||
def request_received(self, headers, stream_id):
|
||||
if stream_id not in self.requests:
|
||||
self.requests[stream_id] = []
|
||||
self.requests[stream_id].append({"headers": headers, "data": b""})
|
||||
|
||||
def receive_data(self, data, stream_id):
|
||||
self.requests[stream_id][-1]["data"] += data
|
||||
|
||||
def stream_complete(self, stream_id):
|
||||
request = self.requests[stream_id].pop(0)
|
||||
if not self.requests[stream_id]:
|
||||
del self.requests[stream_id]
|
||||
|
||||
request_headers = dict(request["headers"])
|
||||
request_data = request["data"]
|
||||
|
||||
response_body = json.dumps(
|
||||
{
|
||||
"method": request_headers[b":method"].decode(),
|
||||
"path": request_headers[b":path"].decode(),
|
||||
"body": request_data.decode(),
|
||||
}
|
||||
).encode()
|
||||
|
||||
response_headers = (
|
||||
(b":status", b"200"),
|
||||
(b"content-length", str(len(response_body)).encode()),
|
||||
)
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.conn.send_data(stream_id, response_body, end_stream=True)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
def app(request):
|
||||
content = json.dumps({
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"body": request.content.decode(),
|
||||
}).encode()
|
||||
headers = {'Content-Length': str(len(content))}
|
||||
return Response(200, headers=headers, content=content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_get_request():
|
||||
server = MockServer()
|
||||
conn = HTTP2Connection(reader=server, writer=server)
|
||||
request = Request("GET", "http://example.org")
|
||||
request.prepare()
|
||||
def test_http2_get_request():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
|
||||
response = await conn.send(request)
|
||||
with Client(backend=backend) as client:
|
||||
response = client.get("http://example.org")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {"method": "GET", "path": "/", "body": ""}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_post_request():
|
||||
server = MockServer()
|
||||
conn = HTTP2Connection(reader=server, writer=server)
|
||||
request = Request("POST", "http://example.org", data=b"<data>")
|
||||
request.prepare()
|
||||
def test_http2_post_request():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
|
||||
response = await conn.send(request)
|
||||
with Client(backend=backend) as client:
|
||||
response = client.post("http://example.org", data=b"<data>")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.content) == {
|
||||
@ -109,21 +40,13 @@ async def test_http2_post_request():
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_multiple_requests():
|
||||
server = MockServer()
|
||||
conn = HTTP2Connection(reader=server, writer=server)
|
||||
request_1 = Request("GET", "http://example.org/1")
|
||||
request_2 = Request("GET", "http://example.org/2")
|
||||
request_3 = Request("GET", "http://example.org/3")
|
||||
def test_http2_multiple_requests():
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
|
||||
request_1.prepare()
|
||||
request_2.prepare()
|
||||
request_3.prepare()
|
||||
|
||||
response_1 = await conn.send(request_1)
|
||||
response_2 = await conn.send(request_2)
|
||||
response_3 = await conn.send(request_3)
|
||||
with Client(backend=backend) as client:
|
||||
response_1 = client.get("http://example.org/1")
|
||||
response_2 = client.get("http://example.org/2")
|
||||
response_3 = client.get("http://example.org/3")
|
||||
|
||||
assert response_1.status_code == 200
|
||||
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}
|
||||
@ -133,5 +56,3 @@ async def test_http2_multiple_requests():
|
||||
|
||||
assert response_3.status_code == 200
|
||||
assert json.loads(response_3.content) == {"method": "GET", "path": "/3", "body": ""}
|
||||
|
||||
await conn.close()
|
||||
|
||||
108
tests/dispatch/utils.py
Normal file
108
tests/dispatch/utils.py
Normal file
@ -0,0 +1,108 @@
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
|
||||
from httpcore import AsyncioBackend, BaseReader, BaseWriter, Protocol, Request, TimeoutConfig
|
||||
|
||||
|
||||
class MockHTTP2Backend(AsyncioBackend):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
|
||||
server = MockHTTP2Server(self.app)
|
||||
return (server, server, Protocol.HTTP_2)
|
||||
|
||||
|
||||
class MockHTTP2Server(BaseReader, BaseWriter):
|
||||
"""
|
||||
This class exposes Reader and Writer style interfaces.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
config = h2.config.H2Configuration(client_side=False)
|
||||
self.conn = h2.connection.H2Connection(config=config)
|
||||
self.app = app
|
||||
self.buffer = b""
|
||||
self.requests = {}
|
||||
|
||||
# BaseReader interface
|
||||
|
||||
async def read(self, n, timeout) -> bytes:
|
||||
send, self.buffer = self.buffer[:n], self.buffer[n:]
|
||||
return send
|
||||
|
||||
# BaseWriter interface
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
events = self.conn.receive_data(data)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
self.request_received(event.headers, event.stream_id)
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
self.receive_data(event.data, event.stream_id)
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
self.stream_complete(event.stream_id)
|
||||
|
||||
async def write(self, data: bytes, timeout) -> None:
|
||||
self.write_no_block(data)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
# Server implementation
|
||||
|
||||
def request_received(self, headers, stream_id):
|
||||
"""
|
||||
Handler for when the initial part of the HTTP request is received.
|
||||
"""
|
||||
if stream_id not in self.requests:
|
||||
self.requests[stream_id] = []
|
||||
self.requests[stream_id].append({"headers": headers, "data": b""})
|
||||
|
||||
def receive_data(self, data, stream_id):
|
||||
"""
|
||||
Handler for when a data part of the HTTP request is received.
|
||||
"""
|
||||
self.requests[stream_id][-1]["data"] += data
|
||||
|
||||
def stream_complete(self, stream_id):
|
||||
"""
|
||||
Handler for when the HTTP request is completed.
|
||||
"""
|
||||
request = self.requests[stream_id].pop(0)
|
||||
if not self.requests[stream_id]:
|
||||
del self.requests[stream_id]
|
||||
|
||||
headers_dict = dict(request["headers"])
|
||||
|
||||
method = headers_dict[b":method"].decode("ascii")
|
||||
url = "%s://%s%s" % (
|
||||
headers_dict[b":scheme"].decode("ascii"),
|
||||
headers_dict[b":authority"].decode("ascii"),
|
||||
headers_dict[b":path"].decode("ascii"),
|
||||
)
|
||||
headers = [(k, v) for k, v in request["headers"] if not k.startswith(b":")]
|
||||
data = request["data"]
|
||||
|
||||
# Call out to the app.
|
||||
request = Request(method, url, headers=headers, data=data)
|
||||
response = self.app(request)
|
||||
|
||||
# Write the response to the buffer.
|
||||
status_code_bytes = str(int(response.status_code)).encode("ascii")
|
||||
response_headers = [(b":status", status_code_bytes)] + response.headers.raw
|
||||
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.conn.send_data(stream_id, response.content, end_stream=True)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
Loading…
Reference in New Issue
Block a user