Tighten up connection acquiry/release

This commit is contained in:
Tom Christie 2019-04-25 15:31:47 +01:00
parent 286f04f1a6
commit 28c505a70f
6 changed files with 129 additions and 74 deletions

View File

@ -1,3 +1,4 @@
import functools
import typing
import h2.connection
@ -17,12 +18,12 @@ class HTTPConnection(Client):
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
on_release: typing.Callable = None,
pool_release_func: typing.Callable = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
self.on_release = on_release
self.pool_release_func = pool_release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
@ -43,6 +44,11 @@ class HTTPConnection(Client):
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
if self.pool_release_func is None:
on_release = None
else:
on_release = functools.partial(self.pool_release_func, self)
reader, writer, protocol = await connect(
hostname, port, ssl_context, timeout
)
@ -52,7 +58,7 @@ class HTTPConnection(Client):
writer,
origin=self.origin,
timeout=self.timeout,
on_release=self.on_release,
on_release=on_release,
)
else:
self.h11_connection = HTTP11Connection(
@ -60,7 +66,7 @@ class HTTPConnection(Client):
writer,
origin=self.origin,
timeout=self.timeout,
on_release=self.on_release,
on_release=on_release,
)
if self.h2_connection is not None:

View File

@ -1,3 +1,4 @@
import collections.abc
import typing
from .config import (
@ -14,6 +15,66 @@ from .exceptions import PoolTimeout
from .models import Client, Origin, Request, Response
from .streams import PoolSemaphore
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
class ConnectionStore(collections.abc.Sequence):
"""
We need to maintain collections of connections in a way that allows us to:
* Lookup connections by origin.
* Iterate over connections by insertion time.
* Return the total number of connections.
"""
def __init__(self) -> None:
self.all = {} # type: typing.Dict[HTTPConnection, float]
self.by_origin = (
{}
) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
def pop_by_origin(self, origin: Origin) -> typing.Optional[HTTPConnection]:
try:
connections = self.by_origin[origin]
except KeyError:
return None
connection = next(reversed(list(connections.keys())))
del connections[connection]
if not connections:
del self.by_origin[origin]
del self.all[connection]
return connection
def add(self, connection: HTTPConnection) -> None:
self.all[connection] = 0.0
try:
self.by_origin[connection.origin][connection] = 0.0
except KeyError:
self.by_origin[connection.origin] = {connection: 0.0}
def remove(self, connection: HTTPConnection) -> None:
del self.all[connection]
del self.by_origin[connection.origin][connection]
if not self.by_origin[connection.origin]:
del self.by_origin[connection.origin]
def clear(self) -> None:
self.all.clear()
self.by_origin.clear()
def __iter__(self) -> typing.Iterator[HTTPConnection]:
return iter(self.all.keys())
def __getitem__(self, key: typing.Any) -> typing.Any:
if key in self.all:
return key
return None
def __len__(self) -> int:
return len(self.all)
class ConnectionPool(Client):
def __init__(
@ -27,12 +88,14 @@ class ConnectionPool(Client):
self.timeout = timeout
self.limits = limits
self.is_closed = False
self.num_active_connections = 0
self.num_keepalive_connections = 0
self._keepalive_connections = (
{}
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
self._max_connections = PoolSemaphore(limits, timeout)
self.max_connections = PoolSemaphore(limits, timeout)
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
@property
def num_connections(self) -> int:
return len(self.keepalive_connections) + len(self.active_connections)
async def send(
self,
@ -45,56 +108,42 @@ class ConnectionPool(Client):
response = await connection.send(request, ssl=ssl, timeout=timeout)
return response
@property
def num_connections(self) -> int:
return self.num_active_connections + self.num_keepalive_connections
async def acquire_connection(
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
) -> HTTPConnection:
try:
connection = self._keepalive_connections[origin].pop()
if not self._keepalive_connections[origin]:
del self._keepalive_connections[origin]
self.num_keepalive_connections -= 1
self.num_active_connections += 1
connection = self.keepalive_connections.pop_by_origin(origin)
except (KeyError, IndexError):
await self._max_connections.acquire(timeout)
if connection is None:
await self.max_connections.acquire(timeout)
connection = HTTPConnection(
origin,
ssl=self.ssl,
timeout=self.timeout,
on_release=self.release_connection,
pool_release_func=self.release_connection,
)
self.num_active_connections += 1
self.active_connections.add(connection)
return connection
async def release_connection(self, connection: HTTPConnection) -> None:
if connection.is_closed:
self._max_connections.release()
self.num_active_connections -= 1
self.active_connections.remove(connection)
self.max_connections.release()
elif (
self.limits.soft_limit is not None
and self.num_connections > self.limits.soft_limit
):
self._max_connections.release()
self.num_active_connections -= 1
self.active_connections.remove(connection)
self.max_connections.release()
await connection.close()
else:
self.num_active_connections -= 1
self.num_keepalive_connections += 1
try:
self._keepalive_connections[connection.origin].append(connection)
except KeyError:
self._keepalive_connections[connection.origin] = [connection]
self.active_connections.remove(connection)
self.keepalive_connections.add(connection)
async def close(self) -> None:
self.is_closed = True
all_connections = []
for connections in self._keepalive_connections.values():
all_connections.extend(list(connections))
self._keepalive_connections.clear()
for connection in all_connections:
connections = list(self.keepalive_connections)
self.keepalive_connections.clear()
for connection in connections:
await connection.close()

View File

@ -82,7 +82,7 @@ class HTTP11Connection(Client):
protocol="HTTP/1.1",
headers=headers,
body=body,
on_close=self._release,
on_close=self.response_closed,
)
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
@ -106,7 +106,7 @@ class HTTP11Connection(Client):
return event
async def _release(self) -> None:
async def response_closed(self) -> None:
if (
self.h11_state.our_state is h11.DONE
and self.h11_state.their_state is h11.DONE
@ -116,7 +116,7 @@ class HTTP11Connection(Client):
await self.close()
if self.on_release is not None:
await self.on_release(self)
await self.on_release()
async def close(self) -> None:
event = h11.ConnectionClosed()

View File

@ -141,7 +141,7 @@ class HTTP2Connection(Client):
async def release(self) -> None:
if self.on_release is not None:
await self.on_release(self)
await self.on_release()
async def close(self) -> None:
await self.writer.close()

View File

@ -14,7 +14,7 @@ import ssl
import typing
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
OptionalTimeout = typing.Optional[TimeoutConfig]

View File

@ -10,12 +10,12 @@ async def test_keepalive_connections(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -25,12 +25,12 @@ async def test_differing_connection_keys(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 2
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@pytest.mark.asyncio
@ -42,12 +42,12 @@ async def test_soft_limit(server):
async with httpcore.ConnectionPool(limits=limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server):
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
await response.read()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server):
"""
async with httpcore.ConnectionPool() as http:
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
assert http.num_active_connections == 2
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 2
assert len(http.keepalive_connections) == 0
await response_b.read()
assert http.num_active_connections == 1
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 1
await response_a.read()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 2
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
@pytest.mark.asyncio
@ -97,8 +97,8 @@ async def test_close_connections(server):
headers = [(b"connection", b"close")]
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
@pytest.mark.asyncio
@ -110,8 +110,8 @@ async def test_standard_response_close(server):
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.read()
await response.close()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 1
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
@pytest.mark.asyncio
@ -122,5 +122,5 @@ async def test_premature_response_close(server):
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
await response.close()
assert http.num_active_connections == 0
assert http.num_keepalive_connections == 0
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0