Add SyncConnectionPool

This commit is contained in:
Tom Christie 2019-04-23 11:12:37 +01:00
parent 39eb55b2b3
commit 4a36ec74c7
6 changed files with 182 additions and 5 deletions

View File

@ -11,5 +11,6 @@ from .exceptions import (
Timeout,
)
from .pool import ConnectionPool
from .sync import SyncClient, SyncConnectionPool
__version__ = "0.1.1"

51
httpcore/compat.py Normal file
View File

@ -0,0 +1,51 @@
import asyncio
if hasattr(asyncio, "run"):
asyncio_run = asyncio.run
else: # pragma: nocover
def asyncio_run(main, *, debug=False): # type: ignore
if asyncio._get_running_loop() is not None:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)
if not asyncio.iscoroutine(main):
raise ValueError("a coroutine was expected, got {!r}".format(main))
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
def _cancel_all_tasks(loop): # type: ignore
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(
tasks.gather(*to_cancel, loop=loop, return_exceptions=True)
)
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)

View File

@ -101,7 +101,7 @@ class Request:
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method
self.method = method.upper()
self.url = URL(url) if isinstance(url, str) else url
self.headers = list(headers)
if isinstance(body, bytes):

View File

@ -106,13 +106,21 @@ class ConnectionPool(Client):
class ConnectionSemaphore:
def __init__(self, max_connections: int = None):
if max_connections is not None:
self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
self.max_connections = max_connections
@property
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
if not hasattr(self, "_semaphore"):
if self.max_connections is None:
self._semaphore = None
else:
self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
return self._semaphore
async def acquire(self) -> None:
if hasattr(self, "semaphore"):
if self.semaphore is not None:
await self.semaphore.acquire()
def release(self) -> None:
if hasattr(self, "semaphore"):
if self.semaphore is not None:
self.semaphore.release()

View File

@ -0,0 +1,79 @@
import typing
from types import TracebackType
from .compat import asyncio_run
from .config import SSLConfig, TimeoutConfig
from .datastructures import URL, Client, Response
from .pool import ConnectionPool
class SyncResponse:
def __init__(self, response: Response):
self._response = response
@property
def status_code(self) -> int:
return self._response.status_code
@property
def reason(self) -> str:
return self._response.reason
@property
def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return self._response.headers
@property
def body(self) -> bytes:
return self._response.body
def read(self) -> bytes:
return asyncio_run(self._response.read())
class SyncClient:
def __init__(self, client: Client):
self._client = client
def request(
self,
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
stream: bool = False,
) -> SyncResponse:
response = asyncio_run(
self._client.request(
method,
url,
headers=headers,
body=body,
ssl=ssl,
timeout=timeout,
stream=stream,
)
)
return SyncResponse(response)
def close(self) -> None:
asyncio_run(self._client.close())
def __enter__(self) -> "SyncClient":
return self
def __exit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
self.close()
def SyncConnectionPool(*args: typing.Any, **kwargs: typing.Any) -> SyncClient:
client = ConnectionPool(*args, **kwargs) # type: ignore
return SyncClient(client)

38
tests/test_sync.py Normal file
View File

@ -0,0 +1,38 @@
import asyncio
import functools
import pytest
import httpcore
def threadpool(func):
"""
Our sync tests should run in seperate thread to the uvicorn server.
"""
@functools.wraps(func)
async def wrapped(*args, **kwargs):
nonlocal func
loop = asyncio.get_event_loop()
if kwargs:
func = functools.partial(func, **kwargs)
await loop.run_in_executor(None, func, *args)
return pytest.mark.asyncio(wrapped)
@threadpool
def test_get(server):
with httpcore.SyncConnectionPool() as http:
response = http.request("GET", "http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.body == b"Hello, world!"
@threadpool
def test_post(server):
with httpcore.SyncConnectionPool() as http:
response = http.request("POST", "http://127.0.0.1:8000/", body=b"Hello, world!")
assert response.status_code == 200