Add SyncConnectionPool
This commit is contained in:
parent
39eb55b2b3
commit
4a36ec74c7
@ -11,5 +11,6 @@ from .exceptions import (
|
||||
Timeout,
|
||||
)
|
||||
from .pool import ConnectionPool
|
||||
from .sync import SyncClient, SyncConnectionPool
|
||||
|
||||
__version__ = "0.1.1"
|
||||
|
||||
51
httpcore/compat.py
Normal file
51
httpcore/compat.py
Normal file
@ -0,0 +1,51 @@
|
||||
import asyncio
|
||||
|
||||
if hasattr(asyncio, "run"):
|
||||
asyncio_run = asyncio.run
|
||||
else: # pragma: nocover
|
||||
|
||||
def asyncio_run(main, *, debug=False): # type: ignore
|
||||
if asyncio._get_running_loop() is not None:
|
||||
raise RuntimeError(
|
||||
"asyncio.run() cannot be called from a running event loop"
|
||||
)
|
||||
|
||||
if not asyncio.iscoroutine(main):
|
||||
raise ValueError("a coroutine was expected, got {!r}".format(main))
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(debug)
|
||||
return loop.run_until_complete(main)
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
def _cancel_all_tasks(loop): # type: ignore
|
||||
to_cancel = asyncio.all_tasks(loop)
|
||||
if not to_cancel:
|
||||
return
|
||||
|
||||
for task in to_cancel:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(
|
||||
tasks.gather(*to_cancel, loop=loop, return_exceptions=True)
|
||||
)
|
||||
|
||||
for task in to_cancel:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "unhandled exception during asyncio.run() shutdown",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
@ -101,7 +101,7 @@ class Request:
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
):
|
||||
self.method = method
|
||||
self.method = method.upper()
|
||||
self.url = URL(url) if isinstance(url, str) else url
|
||||
self.headers = list(headers)
|
||||
if isinstance(body, bytes):
|
||||
|
||||
@ -106,13 +106,21 @@ class ConnectionPool(Client):
|
||||
|
||||
class ConnectionSemaphore:
|
||||
def __init__(self, max_connections: int = None):
|
||||
if max_connections is not None:
|
||||
self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
|
||||
self.max_connections = max_connections
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
|
||||
if not hasattr(self, "_semaphore"):
|
||||
if self.max_connections is None:
|
||||
self._semaphore = None
|
||||
else:
|
||||
self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if hasattr(self, "semaphore"):
|
||||
if self.semaphore is not None:
|
||||
await self.semaphore.acquire()
|
||||
|
||||
def release(self) -> None:
|
||||
if hasattr(self, "semaphore"):
|
||||
if self.semaphore is not None:
|
||||
self.semaphore.release()
|
||||
|
||||
@ -0,0 +1,79 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .compat import asyncio_run
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .datastructures import URL, Client, Response
|
||||
from .pool import ConnectionPool
|
||||
|
||||
|
||||
class SyncResponse:
|
||||
def __init__(self, response: Response):
|
||||
self._response = response
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self._response.status_code
|
||||
|
||||
@property
|
||||
def reason(self) -> str:
|
||||
return self._response.reason
|
||||
|
||||
@property
|
||||
def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
return self._response.headers
|
||||
|
||||
@property
|
||||
def body(self) -> bytes:
|
||||
return self._response.body
|
||||
|
||||
def read(self) -> bytes:
|
||||
return asyncio_run(self._response.read())
|
||||
|
||||
|
||||
class SyncClient:
|
||||
def __init__(self, client: Client):
|
||||
self._client = client
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
stream: bool = False,
|
||||
) -> SyncResponse:
|
||||
response = asyncio_run(
|
||||
self._client.request(
|
||||
method,
|
||||
url,
|
||||
headers=headers,
|
||||
body=body,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
return SyncResponse(response)
|
||||
|
||||
def close(self) -> None:
|
||||
asyncio_run(self._client.close())
|
||||
|
||||
def __enter__(self) -> "SyncClient":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
def SyncConnectionPool(*args: typing.Any, **kwargs: typing.Any) -> SyncClient:
|
||||
client = ConnectionPool(*args, **kwargs) # type: ignore
|
||||
return SyncClient(client)
|
||||
38
tests/test_sync.py
Normal file
38
tests/test_sync.py
Normal file
@ -0,0 +1,38 @@
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
|
||||
import httpcore
|
||||
|
||||
|
||||
def threadpool(func):
|
||||
"""
|
||||
Our sync tests should run in seperate thread to the uvicorn server.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args, **kwargs):
|
||||
nonlocal func
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
if kwargs:
|
||||
func = functools.partial(func, **kwargs)
|
||||
await loop.run_in_executor(None, func, *args)
|
||||
|
||||
return pytest.mark.asyncio(wrapped)
|
||||
|
||||
|
||||
@threadpool
|
||||
def test_get(server):
|
||||
with httpcore.SyncConnectionPool() as http:
|
||||
response = http.request("GET", "http://127.0.0.1:8000/")
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"Hello, world!"
|
||||
|
||||
|
||||
@threadpool
|
||||
def test_post(server):
|
||||
with httpcore.SyncConnectionPool() as http:
|
||||
response = http.request("POST", "http://127.0.0.1:8000/", body=b"Hello, world!")
|
||||
assert response.status_code == 200
|
||||
Loading…
Reference in New Issue
Block a user