Add PoolSemaphore

This commit is contained in:
Tom Christie 2019-04-25 12:57:18 +01:00
parent 9ab607d8a3
commit 286f04f1a6
3 changed files with 77 additions and 51 deletions

View File

@ -1,4 +1,3 @@
import asyncio
import typing
from .config import (
@ -13,6 +12,7 @@ from .config import (
from .connection import HTTPConnection
from .exceptions import PoolTimeout
from .models import Client, Origin, Request, Response
from .streams import PoolSemaphore
class ConnectionPool(Client):
@ -32,9 +32,7 @@ class ConnectionPool(Client):
self._keepalive_connections = (
{}
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
self._max_connections = ConnectionSemaphore(
max_connections=self.limits.hard_limit
)
self._max_connections = PoolSemaphore(limits, timeout)
async def send(
self,
@ -62,15 +60,7 @@ class ConnectionPool(Client):
self.num_active_connections += 1
except (KeyError, IndexError):
if timeout is None:
pool_timeout = self.timeout.pool_timeout
else:
pool_timeout = timeout.pool_timeout
try:
await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
await self._max_connections.acquire(timeout)
connection = HTTPConnection(
origin,
ssl=self.ssl,
@ -108,25 +98,3 @@ class ConnectionPool(Client):
self._keepalive_connections.clear()
for connection in all_connections:
await connection.close()
class ConnectionSemaphore:
def __init__(self, max_connections: int = None):
self.max_connections = max_connections
@property
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
if not hasattr(self, "_semaphore"):
if self.max_connections is None:
self._semaphore = None
else:
self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
return self._semaphore
async def acquire(self) -> None:
if self.semaphore is not None:
await self.semaphore.acquire()
def release(self) -> None:
if self.semaphore is not None:
self.semaphore.release()

View File

@ -2,7 +2,9 @@
The `Reader` and `Writer` classes here provide a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.
They help encapsulate the timeout logic, make it easier to unit-test
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
These classes help encapsulate the timeout logic, make it easier to unit-test
protocols, and help keep the rest of the package more `async`/`await`
based, and less strictly `asyncio`-specific.
"""
@ -11,8 +13,8 @@ import enum
import ssl
import typing
from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG
from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
OptionalTimeout = typing.Optional[TimeoutConfig]
@ -38,6 +40,17 @@ class BaseWriter:
raise NotImplementedError() # pragma: no cover
class BasePoolSemaphore:
def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
raise NotImplementedError() # pragma: no cover
async def acquire(self, timeout: OptionalTimeout = None) -> None:
raise NotImplementedError() # pragma: no cover
def release(self) -> None:
raise NotImplementedError() # pragma: no cover
class Reader(BaseReader):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
@ -86,6 +99,40 @@ class Writer(BaseWriter):
self.stream_writer.close()
class PoolSemaphore(BasePoolSemaphore):
def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
self.limits = limits
self.timeout = timeout
@property
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
if not hasattr(self, "_semaphore"):
max_connections = self.limits.hard_limit
if max_connections is None:
self._semaphore = None
else:
self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
return self._semaphore
async def acquire(self, timeout: OptionalTimeout = None) -> None:
if self.semaphore is None:
return
if timeout is None:
timeout = self.timeout
try:
await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
def release(self) -> None:
if self.semaphore is None:
return
self.semaphore.release()
async def connect(
hostname: str,
port: int,

View File

@ -1,8 +1,9 @@
import json
import h2.config
import h2.connection
import h2.events
import pytest
import json
import httpcore
@ -61,15 +62,15 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
request_headers = dict(request["headers"])
request_data = request["data"]
response_body = json.dumps({
"method": request_headers[b":method"].decode(),
"path": request_headers[b":path"].decode(),
"body": request_data.decode()
}).encode()
response_body = json.dumps(
{
"method": request_headers[b":method"].decode(),
"path": request_headers[b":path"].decode(),
"body": request_data.decode(),
}
).encode()
response_headers = (
(b":status", b"200"),
)
response_headers = ((b":status", b"200"),)
self.conn.send_headers(stream_id, response_headers)
self.conn.send_data(stream_id, response_body, end_stream=True)
self.buffer += self.conn.data_to_send()
@ -79,7 +80,9 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
async def test_http2_get_request():
server = MockServer()
origin = httpcore.Origin("http://example.org")
async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
async with httpcore.HTTP2Connection(
reader=server, writer=server, origin=origin
) as client:
response = await client.request("GET", "http://example.org")
assert response.status_code == 200
assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
@ -89,17 +92,25 @@ async def test_http2_get_request():
async def test_http2_post_request():
server = MockServer()
origin = httpcore.Origin("http://example.org")
async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
async with httpcore.HTTP2Connection(
reader=server, writer=server, origin=origin
) as client:
response = await client.request("POST", "http://example.org", body=b"<data>")
assert response.status_code == 200
assert json.loads(response.body) == {"method": "POST", "path": "/", "body": "<data>"}
assert json.loads(response.body) == {
"method": "POST",
"path": "/",
"body": "<data>",
}
@pytest.mark.asyncio
async def test_http2_multiple_requests():
server = MockServer()
origin = httpcore.Origin("http://example.org")
async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
async with httpcore.HTTP2Connection(
reader=server, writer=server, origin=origin
) as client:
response_1 = await client.request("GET", "http://example.org/1")
response_2 = await client.request("GET", "http://example.org/2")
response_3 = await client.request("GET", "http://example.org/3")