Connections

This commit is contained in:
Tom Christie 2019-04-06 13:18:39 +01:00
parent 6a4376b202
commit 86263fa073
13 changed files with 464 additions and 184 deletions

View File

@ -63,7 +63,8 @@ of it, and exposes only plain datastructures that reflect the network response.
```python
import httpcore
response = await httpcore.request('GET', 'http://example.com')
http = httpcore.ConnectionPool()
response = await http.request('GET', 'http://example.com')
assert response.status_code == 200
assert response.body == b'Hello, world'
```
@ -71,20 +72,22 @@ assert response.body == b'Hello, world'
Top-level API...
```python
response = await httpcore.request(method, url, [headers], [body], [stream])
http = httpcore.ConnectionPool([ssl], [timeout], [limits])
response = await http.request(method, url, [headers], [body], [stream])
```
Explicit PoolManager...
ConnectionPool as a context-manager...
```python
async with httpcore.PoolManager([ssl], [timeout], [limits]) as pool:
response = await pool.request(method, url, [headers], [body], [stream])
async with httpcore.ConnectionPool([ssl], [timeout], [limits]) as http:
response = await http.request(method, url, [headers], [body], [stream])
```
Streaming...
```python
response = await httpcore.request(method, url, stream=True)
http = httpcore.ConnectionPool()
response = await http.request(method, url, stream=True)
async for part in response.stream():
...
```
@ -100,7 +103,7 @@ import httpcore
class GatewayServer:
def __init__(self, base_url):
self.base_url = base_url
self.pool = httpcore.PoolManager()
self.http = httpcore.ConnectionPool()
async def __call__(self, scope, receive, send):
assert scope['type'] == 'http'
@ -122,7 +125,7 @@ class GatewayServer:
if not message.get('more_body', False):
break
response = await self.pool.request(
response = await self.http.request(
method, url, headers=headers, body=body, stream=True
)

View File

@ -1,5 +1,6 @@
from .api import PoolManager, Response, request
from .config import PoolLimits, SSLConfig, TimeoutConfig
from .datastructures import URL, Request, Response
from .exceptions import ResponseClosed, StreamConsumed
from .pool import ConnectionPool
__version__ = "0.0.2"

View File

@ -1,67 +0,0 @@
import typing
from types import TracebackType
from .config import (
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
PoolLimits,
SSLConfig,
TimeoutConfig,
)
from .models import Response
async def request(
method: str,
url: str,
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
stream: bool = False,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
) -> Response:
async with PoolManager(ssl=ssl, timeout=timeout) as pool:
return await pool.request(
method=method, url=url, headers=headers, body=body, stream=stream
)
class PoolManager:
def __init__(
self,
*,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
limits: PoolLimits = DEFAULT_POOL_LIMITS,
):
self.ssl = ssl
self.timeout = timeout
self.limits = limits
self.is_closed = False
async def request(
self,
method: str,
url: str,
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
stream: bool = False,
) -> Response:
raise NotImplementedError()
async def close(self) -> None:
self.is_closed = True
async def __aenter__(self) -> "PoolManager":
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.close()

View File

@ -1,5 +1,7 @@
import typing
import certifi
class SSLConfig:
"""
@ -52,3 +54,4 @@ class PoolLimits:
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False)
DEFAULT_CA_BUNDLE_PATH = certifi.where()

View File

@ -1,62 +1,119 @@
from config import TimeoutConfig
import asyncio
import h11
import ssl
import typing
import h11
from .config import TimeoutConfig
from .datastructures import Request, Response
from .exceptions import ConnectTimeout, ReadTimeout
H11Event = typing.Union[
h11.Request,
h11.Response,
h11.InformationalResponse,
h11.Data,
h11.EndOfMessage,
h11.ConnectionClosed,
]
class Connection:
def __init__(self):
def __init__(self, timeout: TimeoutConfig):
self.reader = None
self.writer = None
self.state = h11.Connection(our_role=h11.CLIENT)
self.timeout = timeout
async def open(self, host: str, port: int, ssl: ssl.SSLContext):
async def open(
self,
hostname: str,
port: int,
*,
ssl: typing.Union[bool, ssl.SSLContext] = False
) -> None:
try:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(host, port, ssl=ssl), timeout
self.reader, self.writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl),
self.timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
async def send(self, request: Request) -> Response:
method = request.method
async def send(self, request: Request, stream: bool=False) -> Response:
method = request.method.encode()
target = request.url.target
host_header = (b"host", request.url.netloc.encode("ascii"))
if request.is_streaming:
content_length = (b"transfer-encoding", b"chunked")
else:
content_length = (b"content-length", str(len(request.body)).encode())
target = request.url.path
if request.url.query:
target += "?" + request.url.query
headers = [host_header, content_length] + request.headers
headers = [
("host", request.url.netloc)
] += request.headers
# Send the request method, path/query, and headers.
#  Start sending the request.
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event)
# Send the request body.
if request.is_streaming:
async for data in request.raw():
async for data in request.stream():
event = h11.Data(data=data)
await self._send_event(event)
else:
elif request.body:
event = h11.Data(data=request.body)
await self._send_event(event)
# Finalize sending the request.
event = h11.EndOfMessage()
await connection.send_event(event)
await self._send_event(event)
async def _send_event(self, message):
data = self.state.send(message)
# Start getting the response.
event = await self._receive_event()
if isinstance(event, h11.InformationalResponse):
event = await self._receive_event()
assert isinstance(event, h11.Response)
status_code = event.status_code
headers = event.headers
if stream:
return Response(status_code=status_code, headers=headers, body=self.body_iter())
#  Get the response body.
body = b""
event = await self._receive_event()
while isinstance(event, h11.Data):
body += event.data
event = await self._receive_event()
assert isinstance(event, h11.EndOfMessage)
await self.close()
return Response(status_code=status_code, headers=headers, body=body)
async def body_iter(self) -> typing.Iterable[bytes]:
event = await self._receive_event()
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event()
assert isinstance(event, h11.EndOfMessage)
await self.close()
async def _send_event(self, event: H11Event) -> None:
assert self.writer is not None
data = self.state.send(event)
self.writer.write(data)
async def _receive_event(self, timeout):
async def _receive_event(self) -> H11Event:
assert self.reader is not None
event = self.state.next_event()
while type(event) is h11.NEED_DATA:
while event is h11.NEED_DATA:
try:
data = await asyncio.wait_for(self.reader.read(2048), timeout)
data = await asyncio.wait_for(
self.reader.read(2048), self.timeout.read_timeout
)
except asyncio.TimeoutError:
raise ReadTimeout()
self.state.receive_data(data)
@ -64,7 +121,8 @@ class Connection:
return event
async def close(self):
self.writer.close()
if hasattr(self.writer, "wait_closed"):
await self.writer.wait_closed()
async def close(self) -> None:
if self.writer is not None:
self.writer.close()
if hasattr(self.writer, "wait_closed"):
await self.writer.wait_closed()

145
httpcore/datastructures.py Normal file
View File

@ -0,0 +1,145 @@
import typing
from urllib.parse import urlsplit
from .decoders import IdentityDecoder
from .exceptions import ResponseClosed, StreamConsumed
class URL:
def __init__(self, url: str = "") -> None:
self.components = urlsplit(url)
if not self.components.scheme:
raise ValueError("No scheme included in URL.")
if self.components.scheme not in ("http", "https"):
raise ValueError('URL scheme must be "http" or "https".')
if not self.components.hostname:
raise ValueError("No hostname included in URL.")
@property
def scheme(self) -> str:
return self.components.scheme
@property
def netloc(self) -> str:
return self.components.netloc
@property
def path(self) -> str:
return self.components.path
@property
def query(self) -> str:
return self.components.query
@property
def hostname(self) -> str:
return self.components.hostname
@property
def port(self) -> int:
port = self.components.port
if port is None:
return {"https": 443, "http": 80}[self.scheme]
return port
@property
def target(self) -> str:
path = self.path or "/"
query = self.query
if query:
return path + "?" + query
return path
@property
def is_secure(self) -> bool:
return self.components.scheme == "https"
class Request:
def __init__(
self,
method: str,
url: URL,
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method
self.url = url
self.headers = list(headers)
if isinstance(body, bytes):
self.is_streaming = False
self.body = body
else:
self.is_streaming = True
self.body_aiter = body
async def stream(self) -> typing.AsyncIterator[bytes]:
assert self.is_streaming
async for part in self.body_aiter:
yield part
class Response:
def __init__(
self,
status_code: int,
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
):
self.status_code = status_code
self.headers = list(headers)
self.on_close = on_close
self.is_closed = False
self.is_streamed = False
self.decoder = IdentityDecoder()
if isinstance(body, bytes):
self.is_closed = True
self.body = body
else:
self.body_aiter = body
async def read(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "body"):
body = b""
async for part in self.stream():
body += part
self.body = body
return self.body
async def stream(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This will allow us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "body"):
yield self.body
else:
async for chunk in self.raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
async def raw(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
if self.is_streamed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_streamed = True
async for part in self.body_aiter:
yield part
await self.close()
async def close(self) -> None:
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:
await self.on_close()

View File

@ -8,7 +8,7 @@ class IdentityDecoder:
return chunk
def flush(self) -> bytes:
return b''
return b""
# class DeflateDecoder:

View File

@ -1,68 +0,0 @@
import typing
from .decoders import IdentityDecoder
from .exceptions import ResponseClosed, StreamConsumed
class Response:
def __init__(
self,
status_code: int,
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
):
self.status_code = status_code
self.headers = list(headers)
self.on_close = on_close
self.is_closed = False
self.is_streamed = False
self.decoder = IdentityDecoder()
if isinstance(body, bytes):
self.is_closed = True
self.body = body
else:
self.body_aiter = body
async def read(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "body"):
body = b""
async for part in self.stream():
body += part
self.body = body
return self.body
async def stream(self):
"""
A byte-iterator over the decoded response content.
This will allow us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "body"):
yield self.body
else:
async for chunk in self.raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
async def raw(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
if self.is_streamed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_streamed = True
async for part in self.body_aiter():
yield part
await self.close()
async def close(self) -> None:
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:
await self.on_close()

126
httpcore/pool.py Normal file
View File

@ -0,0 +1,126 @@
import asyncio
import os
import ssl
import typing
from types import TracebackType
from .config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
PoolLimits,
SSLConfig,
TimeoutConfig,
)
from .connections import Connection
from .datastructures import URL, Request, Response
class ConnectionPool:
def __init__(
self,
*,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
limits: PoolLimits = DEFAULT_POOL_LIMITS,
):
self.ssl_config = ssl
self.timeout = timeout
self.limits = limits
self.is_closed = False
async def request(
self,
method: str,
url: str,
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
stream: bool = False,
) -> Response:
parsed_url = URL(url)
request = Request(method, parsed_url, headers=headers, body=body)
ssl_context = await self.get_ssl_context(parsed_url)
connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
response = await connection.send(request, stream=stream)
return response
async def acquire_connection(
self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
) -> Connection:
connection = Connection(timeout=self.timeout)
await connection.open(url.hostname, url.port, ssl=ssl)
return connection
async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
if not url.is_secure:
return False
if not hasattr(self, "ssl_context"):
if not self.ssl_config.verify:
self.ssl_context = self.get_ssl_context_no_verify()
else:
# Run the SSL loading in a threadpool, since it makes disk accesses.
loop = asyncio.get_event_loop()
self.ssl_context = await loop.run_in_executor(
None, self.get_ssl_context_verify
)
return self.ssl_context
def get_ssl_context_no_verify(self) -> ssl.SSLContext:
"""
Return an SSL context for unverified connections.
"""
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_COMPRESSION
context.set_default_verify_paths()
return context
def get_ssl_context_verify(self) -> ssl.SSLContext:
"""
Return an SSL context for verified connections.
"""
cert = self.ssl_config.cert
verify = self.ssl_config.verify
if isinstance(verify, bool):
ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
elif os.path.exists(verify):
ca_bundle_path = verify
else:
raise IOError(
"Could not find a suitable TLS CA certificate bundle, "
"invalid path: {}".format(verify)
)
context = ssl.create_default_context()
if os.path.isfile(ca_bundle_path):
context.load_verify_locations(cafile=ca_bundle_path)
elif os.path.isdir(ca_bundle_path):
context.load_verify_locations(capath=ca_bundle_path)
if cert is not None:
if isinstance(cert, str):
context.load_cert_chain(certfile=cert)
else:
context.load_cert_chain(certfile=cert[0], keyfile=cert[1])
return context
async def close(self) -> None:
self.is_closed = True
async def __aenter__(self) -> "ConnectionPool":
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.close()

View File

@ -1,3 +1,4 @@
certifi
h11
# Testing
@ -9,3 +10,4 @@ mypy
pytest
pytest-asyncio
pytest-cov
uvicorn

34
tests/conftest.py Normal file
View File

@ -0,0 +1,34 @@
import asyncio
import json
import pytest
from uvicorn.config import Config
from uvicorn.main import Server
async def app(scope, receive, send):
assert scope['type'] == 'http'
await send({
'type': 'http.response.start',
'status': 200,
'headers': [
[b'content-type', b'text/plain'],
]
})
await send({
'type': 'http.response.body',
'body': b'Hello, world!',
})
@pytest.fixture
async def server():
config = Config(app=app, lifespan="off")
server = Server(config=config)
task = asyncio.ensure_future(server.serve())
try:
while not server.started:
await asyncio.sleep(0.0001)
yield server
finally:
task.cancel()

View File

@ -0,0 +1,38 @@
import pytest
import httpcore
@pytest.mark.asyncio
async def test_get(server):
async with httpcore.ConnectionPool() as http:
response = await http.request('GET', "http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.body == b'Hello, world!'
@pytest.mark.asyncio
async def test_post(server):
async with httpcore.ConnectionPool() as http:
response = await http.request('POST', "http://127.0.0.1:8000/", body=b"Hello, world!")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_stream_response(server):
async with httpcore.ConnectionPool() as http:
response = await http.request('GET', "http://127.0.0.1:8000/", stream=True)
assert response.status_code == 200
assert not hasattr(response, 'body')
body = await response.read()
assert body == b'Hello, world!'
@pytest.mark.asyncio
async def test_stream_request(server):
async def hello_world():
yield b"Hello, "
yield b"world!"
async with httpcore.ConnectionPool() as http:
response = await http.request('POST', "http://127.0.0.1:8000/", body=hello_world())
assert response.status_code == 200

View File

@ -3,17 +3,21 @@ import pytest
import httpcore
class MockRequests(httpcore.PoolManager):
async def request(self, method, url, *, headers = (), body = b'', stream = False) -> httpcore.Response:
class MockHTTP(httpcore.ConnectionPool):
async def request(
self, method, url, *, headers=(), body=b"", stream=False
) -> httpcore.Response:
if stream:
async def streaming_body():
yield b"Hello, "
yield b"world!"
return httpcore.Response(200, body=streaming_body)
return httpcore.Response(200, body=streaming_body())
return httpcore.Response(200, body=b"Hello, world!")
http = MockRequests()
http = MockHTTP()
@pytest.mark.asyncio
@ -47,7 +51,7 @@ async def test_stream_response():
assert response.body == b"Hello, world!"
assert response.is_closed
body = b''
body = b""
async for part in response.stream():
body += part
@ -61,7 +65,7 @@ async def test_read_streaming_response():
response = await http.request("GET", "http://example.com", stream=True)
assert response.status_code == 200
assert not hasattr(response, 'body')
assert not hasattr(response, "body")
assert not response.is_closed
body = await response.read()
@ -76,15 +80,15 @@ async def test_stream_streaming_response():
response = await http.request("GET", "http://example.com", stream=True)
assert response.status_code == 200
assert not hasattr(response, 'body')
assert not hasattr(response, "body")
assert not response.is_closed
body = b''
body = b""
async for part in response.stream():
body += part
assert body == b"Hello, world!"
assert not hasattr(response, 'body')
assert not hasattr(response, "body")
assert response.is_closed
@ -92,13 +96,14 @@ async def test_stream_streaming_response():
async def test_cannot_read_after_stream_consumed():
response = await http.request("GET", "http://example.com", stream=True)
body = b''
body = b""
async for part in response.stream():
body += part
with pytest.raises(httpcore.StreamConsumed):
await response.read()
@pytest.mark.asyncio
async def test_cannot_read_after_response_closed():
response = await http.request("GET", "http://example.com", stream=True)