Merge pull request #4 from encode/connection-pooling

Connection pooling
This commit is contained in:
Tom Christie 2019-04-17 17:03:29 +01:00 committed by GitHub
commit 68c468a67e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 253 additions and 41 deletions

View File

@ -45,13 +45,17 @@ class PoolLimits:
Limits on the number of connections in a connection pool.
"""
def __init__(self, *, max_hosts: int, conns_per_host: int, hard_limit: bool):
self.max_hosts = max_hosts
self.conns_per_host = conns_per_host
def __init__(
self,
*,
soft_limit: typing.Optional[int] = None,
hard_limit: typing.Optional[int] = None
):
self.soft_limit = soft_limit
self.hard_limit = hard_limit
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False)
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
DEFAULT_CA_BUNDLE_PATH = certifi.where()

View File

@ -19,18 +19,19 @@ H11Event = typing.Union[
class Connection:
def __init__(self, timeout: TimeoutConfig):
def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None):
self.reader = None
self.writer = None
self.state = h11.Connection(our_role=h11.CLIENT)
self.timeout = timeout
self.on_release = on_release
@property
def is_closed(self) -> bool:
return self.state.our_state in (h11.CLOSED, h11.ERROR)
async def open(
self,
hostname: str,
port: int,
*,
ssl: typing.Union[bool, ssl.SSLContext] = False
self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None
) -> None:
try:
self.reader, self.writer = await asyncio.wait_for( # type: ignore
@ -40,7 +41,7 @@ class Connection:
except asyncio.TimeoutError:
raise ConnectTimeout()
async def send(self, request: Request, stream: bool = False) -> Response:
async def send(self, request: Request) -> Response:
method = request.method.encode()
target = request.url.target
headers = request.headers
@ -69,29 +70,17 @@ class Connection:
assert isinstance(event, h11.Response)
status_code = event.status_code
headers = event.headers
body = self._body_iter()
return Response(
status_code=status_code, headers=headers, body=body, on_close=self._release
)
if stream:
body_iter = self.body_iter()
return Response(status_code=status_code, headers=headers, body=body_iter)
#  Get the response body.
body = b""
event = await self._receive_event()
while isinstance(event, h11.Data):
body += event.data
event = await self._receive_event()
assert isinstance(event, h11.EndOfMessage)
await self.close()
return Response(status_code=status_code, headers=headers, body=body)
async def body_iter(self) -> typing.AsyncIterator[bytes]:
async def _body_iter(self) -> typing.AsyncIterator[bytes]:
event = await self._receive_event()
while isinstance(event, h11.Data):
yield event.data
event = await self._receive_event()
assert isinstance(event, h11.EndOfMessage)
await self.close()
async def _send_event(self, event: H11Event) -> None:
assert self.writer is not None
@ -116,8 +105,27 @@ class Connection:
return event
async def close(self) -> None:
if self.writer is not None:
self.writer.close()
if hasattr(self.writer, "wait_closed"):
await self.writer.wait_closed()
async def _release(self) -> None:
assert self.writer is not None
if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
self.state.start_next_cycle()
else:
self.close()
if self.on_release is not None:
await self.on_release(self)
def close(self) -> None:
assert self.writer is not None
event = h11.ConnectionClosed()
try:
# If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
self.state.send(event)
except h11.ProtocolError:
# If we're in some other state then it's a premature close,
# and we'll end up in h11.ERROR.
pass
self.writer.close()

View File

@ -190,6 +190,10 @@ class Response:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:

View File

@ -1,4 +1,5 @@
import asyncio
import functools
import os
import ssl
import typing
@ -16,6 +17,22 @@ from .config import (
from .connections import Connection
from .datastructures import URL, Request, Response
ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port)
class ConnectionSemaphore:
def __init__(self, max_connections: int = None):
if max_connections is not None:
self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
async def acquire(self) -> None:
if hasattr(self, "semaphore"):
await self.semaphore.acquire()
def release(self) -> None:
if hasattr(self, "semaphore"):
self.semaphore.release()
class ConnectionPool:
def __init__(
@ -29,6 +46,14 @@ class ConnectionPool:
self.timeout = timeout
self.limits = limits
self.is_closed = False
self.num_active_connections = 0
self.num_keepalive_connections = 0
self._connections = (
{}
) # type: typing.Dict[ConnectionKey, typing.List[Connection]]
self._connection_semaphore = ConnectionSemaphore(
max_connections=self.limits.hard_limit
)
async def request(
self,
@ -43,19 +68,62 @@ class ConnectionPool:
request = Request(method, parsed_url, headers=headers, body=body)
ssl_context = await self.get_ssl_context(parsed_url)
connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
response = await connection.send(request, stream=stream)
response = await connection.send(request)
if not stream:
try:
await response.read()
finally:
await response.close()
return response
@property
def num_connections(self) -> int:
return self.num_active_connections + self.num_keepalive_connections
async def acquire_connection(
self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
) -> Connection:
connection = Connection(timeout=self.timeout)
await connection.open(url.hostname, url.port, ssl=ssl)
key = (url.scheme, url.hostname, url.port)
try:
connection = self._connections[key].pop()
if not self._connections[key]:
del self._connections[key]
self.num_keepalive_connections -= 1
self.num_active_connections += 1
except (KeyError, IndexError):
await self._connection_semaphore.acquire()
release = functools.partial(self.release_connection, key=key)
connection = Connection(timeout=self.timeout, on_release=release)
self.num_active_connections += 1
await connection.open(url.hostname, url.port, ssl=ssl)
return connection
async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
async def release_connection(
self, connection: Connection, key: ConnectionKey
) -> None:
if connection.is_closed:
self._connection_semaphore.release()
self.num_active_connections -= 1
elif (
self.limits.soft_limit is not None
and self.num_connections > self.limits.soft_limit
):
self._connection_semaphore.release()
self.num_active_connections -= 1
connection.close()
else:
self.num_active_connections -= 1
self.num_keepalive_connections += 1
try:
self._connections[key].append(connection)
except KeyError:
self._connections[key] = [connection]
async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]:
if not url.is_secure:
return False
return None
if not hasattr(self, "ssl_context"):
if not self.ssl_config.verify:

View File

@ -47,7 +47,7 @@ setup(
author_email="tom@tomchristie.com",
packages=get_packages("httpcore"),
data_files=[("", ["LICENSE.md"])],
install_requires=["h11"],
install_requires=["h11", "certifi"],
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",

View File

@ -27,4 +27,6 @@ async def server():
await asyncio.sleep(0.0001)
yield server
finally:
task.cancel()
server.should_exit = True
server.force_exit = True
await task

126
tests/test_pool.py Normal file
View File

@ -0,0 +1,126 @@
import pytest
import httpcore
@pytest.mark.asyncio
async def test_keepalive_connections(server):
"""
Connections should default to staying in a keep-alive state.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
@pytest.mark.asyncio
async def test_differing_connection_keys(server):
"""
Connnections to differing connection keys should result in multiple connections.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
response = await http.request("GET", "http://localhost:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 2
@pytest.mark.asyncio
async def test_soft_limit(server):
"""
The soft_limit config should limit the maximum number of keep-alive connections.
"""
limits = httpcore.PoolLimits(soft_limit=1)
async with httpcore.ConnectionPool(limits=limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
response = await http.request("GET", "http://localhost:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
@pytest.mark.asyncio
async def test_streaming_response_holds_connection(server):
"""
A streaming request should hold the connection open until the response is read.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 0
await response.read()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
@pytest.mark.asyncio
async def test_multiple_concurrent_connections(server):
"""
Multiple conncurrent requests should open multiple conncurrent connections.
"""
async with httpcore.ConnectionPool() as http:
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 0
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 2
assert http.num_keepalive_connections == 0
await response_b.read()
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 1
await response_a.read()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 2
@pytest.mark.asyncio
async def test_close_connections(server):
"""
Using a `Connection: close` header should close the connection.
"""
headers = [(b"connection", b"close")]
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 0
@pytest.mark.asyncio
async def test_standard_response_close(server):
"""
A standard close should keep the connection open.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.read()
await response.close()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
@pytest.mark.asyncio
async def test_premature_response_close(server):
"""
A premature close should close the connection.
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.close()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 0