Add PoolSemaphore
This commit is contained in:
parent
9ab607d8a3
commit
286f04f1a6
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from .config import (
|
||||
@ -13,6 +12,7 @@ from .config import (
|
||||
from .connection import HTTPConnection
|
||||
from .exceptions import PoolTimeout
|
||||
from .models import Client, Origin, Request, Response
|
||||
from .streams import PoolSemaphore
|
||||
|
||||
|
||||
class ConnectionPool(Client):
|
||||
@ -32,9 +32,7 @@ class ConnectionPool(Client):
|
||||
self._keepalive_connections = (
|
||||
{}
|
||||
) # type: typing.Dict[Origin, typing.List[HTTPConnection]]
|
||||
self._max_connections = ConnectionSemaphore(
|
||||
max_connections=self.limits.hard_limit
|
||||
)
|
||||
self._max_connections = PoolSemaphore(limits, timeout)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@ -62,15 +60,7 @@ class ConnectionPool(Client):
|
||||
self.num_active_connections += 1
|
||||
|
||||
except (KeyError, IndexError):
|
||||
if timeout is None:
|
||||
pool_timeout = self.timeout.pool_timeout
|
||||
else:
|
||||
pool_timeout = timeout.pool_timeout
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise PoolTimeout()
|
||||
await self._max_connections.acquire(timeout)
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
ssl=self.ssl,
|
||||
@ -108,25 +98,3 @@ class ConnectionPool(Client):
|
||||
self._keepalive_connections.clear()
|
||||
for connection in all_connections:
|
||||
await connection.close()
|
||||
|
||||
|
||||
class ConnectionSemaphore:
|
||||
def __init__(self, max_connections: int = None):
|
||||
self.max_connections = max_connections
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
||||
if not hasattr(self, "_semaphore"):
|
||||
if self.max_connections is None:
|
||||
self._semaphore = None
|
||||
else:
|
||||
self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.semaphore is not None:
|
||||
await self.semaphore.acquire()
|
||||
|
||||
def release(self) -> None:
|
||||
if self.semaphore is not None:
|
||||
self.semaphore.release()
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
The `Reader` and `Writer` classes here provide a lightweight layer over
|
||||
`asyncio.StreamReader` and `asyncio.StreamWriter`.
|
||||
|
||||
They help encapsulate the timeout logic, make it easier to unit-test
|
||||
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
|
||||
|
||||
These classes help encapsulate the timeout logic, make it easier to unit-test
|
||||
protocols, and help keep the rest of the package more `async`/`await`
|
||||
based, and less strictly `asyncio`-specific.
|
||||
"""
|
||||
@ -11,8 +13,8 @@ import enum
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG
|
||||
from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
|
||||
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
|
||||
|
||||
OptionalTimeout = typing.Optional[TimeoutConfig]
|
||||
|
||||
@ -38,6 +40,17 @@ class BaseWriter:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class BasePoolSemaphore:
|
||||
def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def acquire(self, timeout: OptionalTimeout = None) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class Reader(BaseReader):
|
||||
def __init__(
|
||||
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
|
||||
@ -86,6 +99,40 @@ class Writer(BaseWriter):
|
||||
self.stream_writer.close()
|
||||
|
||||
|
||||
class PoolSemaphore(BasePoolSemaphore):
|
||||
def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
|
||||
self.limits = limits
|
||||
self.timeout = timeout
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
||||
if not hasattr(self, "_semaphore"):
|
||||
max_connections = self.limits.hard_limit
|
||||
if max_connections is None:
|
||||
self._semaphore = None
|
||||
else:
|
||||
self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self, timeout: OptionalTimeout = None) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise PoolTimeout()
|
||||
|
||||
def release(self) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
self.semaphore.release()
|
||||
|
||||
|
||||
async def connect(
|
||||
hostname: str,
|
||||
port: int,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import json
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import pytest
|
||||
import json
|
||||
|
||||
import httpcore
|
||||
|
||||
@ -61,15 +62,15 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
|
||||
request_headers = dict(request["headers"])
|
||||
request_data = request["data"]
|
||||
|
||||
response_body = json.dumps({
|
||||
"method": request_headers[b":method"].decode(),
|
||||
"path": request_headers[b":path"].decode(),
|
||||
"body": request_data.decode()
|
||||
}).encode()
|
||||
response_body = json.dumps(
|
||||
{
|
||||
"method": request_headers[b":method"].decode(),
|
||||
"path": request_headers[b":path"].decode(),
|
||||
"body": request_data.decode(),
|
||||
}
|
||||
).encode()
|
||||
|
||||
response_headers = (
|
||||
(b":status", b"200"),
|
||||
)
|
||||
response_headers = ((b":status", b"200"),)
|
||||
self.conn.send_headers(stream_id, response_headers)
|
||||
self.conn.send_data(stream_id, response_body, end_stream=True)
|
||||
self.buffer += self.conn.data_to_send()
|
||||
@ -79,7 +80,9 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
|
||||
async def test_http2_get_request():
|
||||
server = MockServer()
|
||||
origin = httpcore.Origin("http://example.org")
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
|
||||
async with httpcore.HTTP2Connection(
|
||||
reader=server, writer=server, origin=origin
|
||||
) as client:
|
||||
response = await client.request("GET", "http://example.org")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
|
||||
@ -89,17 +92,25 @@ async def test_http2_get_request():
|
||||
async def test_http2_post_request():
|
||||
server = MockServer()
|
||||
origin = httpcore.Origin("http://example.org")
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
|
||||
async with httpcore.HTTP2Connection(
|
||||
reader=server, writer=server, origin=origin
|
||||
) as client:
|
||||
response = await client.request("POST", "http://example.org", body=b"<data>")
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.body) == {"method": "POST", "path": "/", "body": "<data>"}
|
||||
assert json.loads(response.body) == {
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"body": "<data>",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http2_multiple_requests():
|
||||
server = MockServer()
|
||||
origin = httpcore.Origin("http://example.org")
|
||||
async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
|
||||
async with httpcore.HTTP2Connection(
|
||||
reader=server, writer=server, origin=origin
|
||||
) as client:
|
||||
response_1 = await client.request("GET", "http://example.org/1")
|
||||
response_2 = await client.request("GET", "http://example.org/2")
|
||||
response_3 = await client.request("GET", "http://example.org/3")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user