diff --git a/httpcore/config.py b/httpcore/config.py index d169e0af..2db89342 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -45,13 +45,17 @@ class PoolLimits: Limits on the number of connections in a connection pool. """ - def __init__(self, *, max_hosts: int, conns_per_host: int, hard_limit: bool): - self.max_hosts = max_hosts - self.conns_per_host = conns_per_host + def __init__( + self, + *, + soft_limit: typing.Optional[int] = None, + hard_limit: typing.Optional[int] = None + ): + self.soft_limit = soft_limit self.hard_limit = hard_limit DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True) DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0) -DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False) +DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100) DEFAULT_CA_BUNDLE_PATH = certifi.where() diff --git a/httpcore/connections.py b/httpcore/connections.py index e181cb50..205482e8 100644 --- a/httpcore/connections.py +++ b/httpcore/connections.py @@ -19,18 +19,19 @@ H11Event = typing.Union[ class Connection: - def __init__(self, timeout: TimeoutConfig): + def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None): self.reader = None self.writer = None self.state = h11.Connection(our_role=h11.CLIENT) self.timeout = timeout + self.on_release = on_release + + @property + def is_closed(self) -> bool: + return self.state.our_state in (h11.CLOSED, h11.ERROR) async def open( - self, - hostname: str, - port: int, - *, - ssl: typing.Union[bool, ssl.SSLContext] = False + self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None ) -> None: try: self.reader, self.writer = await asyncio.wait_for( # type: ignore @@ -40,7 +41,7 @@ class Connection: except asyncio.TimeoutError: raise ConnectTimeout() - async def send(self, request: Request, stream: bool = False) -> Response: + async def send(self, request: Request) -> Response: method = request.method.encode() target = request.url.target headers = request.headers @@ -69,29 +70,17 @@ class Connection: assert isinstance(event, h11.Response) status_code = event.status_code headers = event.headers + body = self._body_iter() + return Response( + status_code=status_code, headers=headers, body=body, on_close=self._release + ) - if stream: - body_iter = self.body_iter() - return Response(status_code=status_code, headers=headers, body=body_iter) - - #  Get the response body. - body = b"" - event = await self._receive_event() - while isinstance(event, h11.Data): - body += event.data - event = await self._receive_event() - assert isinstance(event, h11.EndOfMessage) - await self.close() - - return Response(status_code=status_code, headers=headers, body=body) - - async def body_iter(self) -> typing.AsyncIterator[bytes]: + async def _body_iter(self) -> typing.AsyncIterator[bytes]: event = await self._receive_event() while isinstance(event, h11.Data): yield event.data event = await self._receive_event() assert isinstance(event, h11.EndOfMessage) - await self.close() async def _send_event(self, event: H11Event) -> None: assert self.writer is not None @@ -116,8 +105,27 @@ class Connection: return event - async def close(self) -> None: - if self.writer is not None: - self.writer.close() - if hasattr(self.writer, "wait_closed"): - await self.writer.wait_closed() + async def _release(self) -> None: + assert self.writer is not None + + if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE: + self.state.start_next_cycle() + else: + self.close() + + if self.on_release is not None: + await self.on_release(self) + + def close(self) -> None: + assert self.writer is not None + + event = h11.ConnectionClosed() + try: + # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. + self.state.send(event) + except h11.ProtocolError: + # If we're in some other state then it's a premature close, + # and we'll end up in h11.ERROR. + pass + + self.writer.close() diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py index 28f3bf91..5f84b466 100644 --- a/httpcore/datastructures.py +++ b/httpcore/datastructures.py @@ -190,6 +190,10 @@ class Response: await self.close() async def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ if not self.is_closed: self.is_closed = True if self.on_close is not None: diff --git a/httpcore/pool.py b/httpcore/pool.py index 75948477..fcec56ea 100644 --- a/httpcore/pool.py +++ b/httpcore/pool.py @@ -1,4 +1,5 @@ import asyncio +import functools import os import ssl import typing @@ -16,6 +17,22 @@ from .config import ( from .connections import Connection from .datastructures import URL, Request, Response +ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port) + + +class ConnectionSemaphore: + def __init__(self, max_connections: int = None): + if max_connections is not None: + self.semaphore = asyncio.BoundedSemaphore(value=max_connections) + + async def acquire(self) -> None: + if hasattr(self, "semaphore"): + await self.semaphore.acquire() + + def release(self) -> None: + if hasattr(self, "semaphore"): + self.semaphore.release() + class ConnectionPool: def __init__( @@ -29,6 +46,14 @@ class ConnectionPool: self.timeout = timeout self.limits = limits self.is_closed = False + self.num_active_connections = 0 + self.num_keepalive_connections = 0 + self._connections = ( + {} + ) # type: typing.Dict[ConnectionKey, typing.List[Connection]] + self._connection_semaphore = ConnectionSemaphore( + max_connections=self.limits.hard_limit + ) async def request( self, @@ -43,19 +68,62 @@ class ConnectionPool: request = Request(method, parsed_url, headers=headers, body=body) ssl_context = await self.get_ssl_context(parsed_url) connection = await self.acquire_connection(parsed_url, ssl=ssl_context) - response = await connection.send(request, stream=stream) + response = await connection.send(request) + if not stream: + try: + await response.read() + finally: + await response.close() return response + @property + def num_connections(self) -> int: + return self.num_active_connections + self.num_keepalive_connections + async def acquire_connection( - self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False + self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None ) -> Connection: - connection = Connection(timeout=self.timeout) - await connection.open(url.hostname, url.port, ssl=ssl) + key = (url.scheme, url.hostname, url.port) + try: + connection = self._connections[key].pop() + if not self._connections[key]: + del self._connections[key] + self.num_keepalive_connections -= 1 + self.num_active_connections += 1 + + except (KeyError, IndexError): + await self._connection_semaphore.acquire() + release = functools.partial(self.release_connection, key=key) + connection = Connection(timeout=self.timeout, on_release=release) + self.num_active_connections += 1 + await connection.open(url.hostname, url.port, ssl=ssl) + return connection - async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]: + async def release_connection( + self, connection: Connection, key: ConnectionKey + ) -> None: + if connection.is_closed: + self._connection_semaphore.release() + self.num_active_connections -= 1 + elif ( + self.limits.soft_limit is not None + and self.num_connections > self.limits.soft_limit + ): + self._connection_semaphore.release() + self.num_active_connections -= 1 + connection.close() + else: + self.num_active_connections -= 1 + self.num_keepalive_connections += 1 + try: + self._connections[key].append(connection) + except KeyError: + self._connections[key] = [connection] + + async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]: if not url.is_secure: - return False + return None if not hasattr(self, "ssl_context"): if not self.ssl_config.verify: diff --git a/setup.py b/setup.py index bf42cf4c..6d709be2 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( author_email="tom@tomchristie.com", packages=get_packages("httpcore"), data_files=[("", ["LICENSE.md"])], - install_requires=["h11"], + install_requires=["h11", "certifi"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Web Environment", diff --git a/tests/conftest.py b/tests/conftest.py index 234cf43b..efb79df1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,4 +27,6 @@ async def server(): await asyncio.sleep(0.0001) yield server finally: - task.cancel() + server.should_exit = True + server.force_exit = True + await task diff --git a/tests/test_pool.py b/tests/test_pool.py new file mode 100644 index 00000000..77a22157 --- /dev/null +++ b/tests/test_pool.py @@ -0,0 +1,126 @@ +import pytest + +import httpcore + + +@pytest.mark.asyncio +async def test_keepalive_connections(server): + """ + Connections should default to staying in a keep-alive state. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_differing_connection_keys(server): + """ + Connnections to differing connection keys should result in multiple connections. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + response = await http.request("GET", "http://localhost:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 2 + + +@pytest.mark.asyncio +async def test_soft_limit(server): + """ + The soft_limit config should limit the maximum number of keep-alive connections. + """ + limits = httpcore.PoolLimits(soft_limit=1) + + async with httpcore.ConnectionPool(limits=limits) as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + response = await http.request("GET", "http://localhost:8000/") + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_streaming_response_holds_connection(server): + """ + A streaming request should hold the connection open until the response is read. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert http.num_active_connections == 1 + assert http.num_keepalive_connections == 0 + + await response.read() + + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_multiple_concurrent_connections(server): + """ + Multiple conncurrent requests should open multiple conncurrent connections. + """ + async with httpcore.ConnectionPool() as http: + response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert http.num_active_connections == 1 + assert http.num_keepalive_connections == 0 + + response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + assert http.num_active_connections == 2 + assert http.num_keepalive_connections == 0 + + await response_b.read() + assert http.num_active_connections == 1 + assert http.num_keepalive_connections == 1 + + await response_a.read() + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 2 + + +@pytest.mark.asyncio +async def test_close_connections(server): + """ + Using a `Connection: close` header should close the connection. + """ + headers = [(b"connection", b"close")] + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 0 + + +@pytest.mark.asyncio +async def test_standard_response_close(server): + """ + A standard close should keep the connection open. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + await response.read() + await response.close() + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 1 + + +@pytest.mark.asyncio +async def test_premature_response_close(server): + """ + A premature close should close the connection. + """ + async with httpcore.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + await response.close() + assert http.num_active_connections == 0 + assert http.num_keepalive_connections == 0