Tighten up connection acquiry/release
This commit is contained in:
parent
286f04f1a6
commit
28c505a70f
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import h2.connection
|
||||
@ -17,12 +18,12 @@ class HTTPConnection(Client):
|
||||
origin: typing.Union[str, Origin],
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
on_release: typing.Callable = None,
|
||||
pool_release_func: typing.Callable = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = ssl
|
||||
self.timeout = timeout
|
||||
self.on_release = on_release
|
||||
self.pool_release_func = pool_release_func
|
||||
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
|
||||
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
|
||||
|
||||
@ -43,6 +44,11 @@ class HTTPConnection(Client):
|
||||
port = self.origin.port
|
||||
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
|
||||
|
||||
if self.pool_release_func is None:
|
||||
on_release = None
|
||||
else:
|
||||
on_release = functools.partial(self.pool_release_func, self)
|
||||
|
||||
reader, writer, protocol = await connect(
|
||||
hostname, port, ssl_context, timeout
|
||||
)
|
||||
@ -52,7 +58,7 @@ class HTTPConnection(Client):
|
||||
writer,
|
||||
origin=self.origin,
|
||||
timeout=self.timeout,
|
||||
on_release=self.on_release,
|
||||
on_release=on_release,
|
||||
)
|
||||
else:
|
||||
self.h11_connection = HTTP11Connection(
|
||||
@ -60,7 +66,7 @@ class HTTPConnection(Client):
|
||||
writer,
|
||||
origin=self.origin,
|
||||
timeout=self.timeout,
|
||||
on_release=self.on_release,
|
||||
on_release=on_release,
|
||||
)
|
||||
|
||||
if self.h2_connection is not None:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
from .config import (
|
||||
@ -14,6 +15,66 @@ from .exceptions import PoolTimeout
|
||||
from .models import Client, Origin, Request, Response
|
||||
from .streams import PoolSemaphore
|
||||
|
||||
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
|
||||
|
||||
class ConnectionStore(collections.abc.Sequence):
|
||||
"""
|
||||
We need to maintain collections of connections in a way that allows us to:
|
||||
|
||||
* Lookup connections by origin.
|
||||
* Iterate over connections by insertion time.
|
||||
* Return the total number of connections.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.all = {} # type: typing.Dict[HTTPConnection, float]
|
||||
self.by_origin = (
|
||||
{}
|
||||
) # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
|
||||
|
||||
def pop_by_origin(self, origin: Origin) -> typing.Optional[HTTPConnection]:
|
||||
try:
|
||||
connections = self.by_origin[origin]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
connection = next(reversed(list(connections.keys())))
|
||||
del connections[connection]
|
||||
if not connections:
|
||||
del self.by_origin[origin]
|
||||
del self.all[connection]
|
||||
|
||||
return connection
|
||||
|
||||
def add(self, connection: HTTPConnection) -> None:
|
||||
self.all[connection] = 0.0
|
||||
try:
|
||||
self.by_origin[connection.origin][connection] = 0.0
|
||||
except KeyError:
|
||||
self.by_origin[connection.origin] = {connection: 0.0}
|
||||
|
||||
def remove(self, connection: HTTPConnection) -> None:
|
||||
del self.all[connection]
|
||||
del self.by_origin[connection.origin][connection]
|
||||
if not self.by_origin[connection.origin]:
|
||||
del self.by_origin[connection.origin]
|
||||
|
||||
def clear(self) -> None:
|
||||
self.all.clear()
|
||||
self.by_origin.clear()
|
||||
|
||||
def __iter__(self) -> typing.Iterator[HTTPConnection]:
|
||||
return iter(self.all.keys())
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> typing.Any:
|
||||
if key in self.all:
|
||||
return key
|
||||
return None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.all)
|
||||
|
||||
|
||||
class ConnectionPool(Client):
|
||||
def __init__(
|
||||
@ -27,12 +88,14 @@ class ConnectionPool(Client):
|
||||
self.timeout = timeout
|
||||
self.limits = limits
|
||||
self.is_closed = False
|
||||
self.num_active_connections = 0
|
||||
self.num_keepalive_connections = 0
|
||||
self._keepalive_connections = (
|
||||
{}
|
||||
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
self._max_connections = PoolSemaphore(limits, timeout)
|
||||
|
||||
self.max_connections = PoolSemaphore(limits, timeout)
|
||||
self.keepalive_connections = ConnectionStore()
|
||||
self.active_connections = ConnectionStore()
|
||||
|
||||
@property
|
||||
def num_connections(self) -> int:
|
||||
return len(self.keepalive_connections) + len(self.active_connections)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@ -45,56 +108,42 @@ class ConnectionPool(Client):
|
||||
response = await connection.send(request, ssl=ssl, timeout=timeout)
|
||||
return response
|
||||
|
||||
@property
|
||||
def num_connections(self) -> int:
|
||||
return self.num_active_connections + self.num_keepalive_connections
|
||||
|
||||
async def acquire_connection(
|
||||
self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
|
||||
) -> HTTPConnection:
|
||||
try:
|
||||
connection = self._keepalive_connections[origin].pop()
|
||||
if not self._keepalive_connections[origin]:
|
||||
del self._keepalive_connections[origin]
|
||||
self.num_keepalive_connections -= 1
|
||||
self.num_active_connections += 1
|
||||
connection = self.keepalive_connections.pop_by_origin(origin)
|
||||
|
||||
except (KeyError, IndexError):
|
||||
await self._max_connections.acquire(timeout)
|
||||
if connection is None:
|
||||
await self.max_connections.acquire(timeout)
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
timeout=self.timeout,
|
||||
on_release=self.release_connection,
|
||||
pool_release_func=self.release_connection,
|
||||
)
|
||||
self.num_active_connections += 1
|
||||
|
||||
self.active_connections.add(connection)
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, connection: HTTPConnection) -> None:
|
||||
if connection.is_closed:
|
||||
self._max_connections.release()
|
||||
self.num_active_connections -= 1
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
elif (
|
||||
self.limits.soft_limit is not None
|
||||
and self.num_connections > self.limits.soft_limit
|
||||
):
|
||||
self._max_connections.release()
|
||||
self.num_active_connections -= 1
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
await connection.close()
|
||||
else:
|
||||
self.num_active_connections -= 1
|
||||
self.num_keepalive_connections += 1
|
||||
try:
|
||||
self._keepalive_connections[connection.origin].append(connection)
|
||||
except KeyError:
|
||||
self._keepalive_connections[connection.origin] = [connection]
|
||||
self.active_connections.remove(connection)
|
||||
self.keepalive_connections.add(connection)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.is_closed = True
|
||||
all_connections = []
|
||||
for connections in self._keepalive_connections.values():
|
||||
all_connections.extend(list(connections))
|
||||
self._keepalive_connections.clear()
|
||||
for connection in all_connections:
|
||||
connections = list(self.keepalive_connections)
|
||||
self.keepalive_connections.clear()
|
||||
for connection in connections:
|
||||
await connection.close()
|
||||
|
||||
@ -82,7 +82,7 @@ class HTTP11Connection(Client):
|
||||
protocol="HTTP/1.1",
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=self._release,
|
||||
on_close=self.response_closed,
|
||||
)
|
||||
|
||||
async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
|
||||
@ -106,7 +106,7 @@ class HTTP11Connection(Client):
|
||||
|
||||
return event
|
||||
|
||||
async def _release(self) -> None:
|
||||
async def response_closed(self) -> None:
|
||||
if (
|
||||
self.h11_state.our_state is h11.DONE
|
||||
and self.h11_state.their_state is h11.DONE
|
||||
@ -116,7 +116,7 @@ class HTTP11Connection(Client):
|
||||
await self.close()
|
||||
|
||||
if self.on_release is not None:
|
||||
await self.on_release(self)
|
||||
await self.on_release()
|
||||
|
||||
async def close(self) -> None:
|
||||
event = h11.ConnectionClosed()
|
||||
|
||||
@ -141,7 +141,7 @@ class HTTP2Connection(Client):
|
||||
|
||||
async def release(self) -> None:
|
||||
if self.on_release is not None:
|
||||
await self.on_release(self)
|
||||
await self.on_release()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.writer.close()
|
||||
|
||||
@ -14,7 +14,7 @@ import ssl
|
||||
import typing
|
||||
|
||||
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
|
||||
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
|
||||
@ -10,12 +10,12 @@ async def test_keepalive_connections(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -25,12 +25,12 @@ async def test_differing_connection_keys(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 2
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -42,12 +42,12 @@ async def test_soft_limit(server):
|
||||
|
||||
async with httpcore.ConnectionPool(limits=limits) as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
response = await http.request("GET", "http://localhost:8000/")
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert http.num_active_connections == 1
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
await response.read()
|
||||
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server):
|
||||
"""
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert http.num_active_connections == 1
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
assert http.num_active_connections == 2
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 2
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
await response_b.read()
|
||||
assert http.num_active_connections == 1
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
await response_a.read()
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 2
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -97,8 +97,8 @@ async def test_close_connections(server):
|
||||
headers = [(b"connection", b"close")]
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -110,8 +110,8 @@ async def test_standard_response_close(server):
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
await response.read()
|
||||
await response.close()
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 1
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -122,5 +122,5 @@ async def test_premature_response_close(server):
|
||||
async with httpcore.ConnectionPool() as http:
|
||||
response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
|
||||
await response.close()
|
||||
assert http.num_active_connections == 0
|
||||
assert http.num_keepalive_connections == 0
|
||||
assert len(http.active_connections) == 0
|
||||
assert len(http.keepalive_connections) == 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user