Merge pull request #17 from encode/close-keepalive-connections
Close outstanding connections on pool.close()
This commit is contained in:
commit
91a2a1b896
19
.travis.yml
Normal file
19
.travis.yml
Normal file
@ -0,0 +1,19 @@
|
||||
sudo: required
|
||||
dist: xenial
|
||||
language: python
|
||||
|
||||
cache: pip
|
||||
|
||||
python:
|
||||
- "3.6"
|
||||
- "3.7"
|
||||
|
||||
install:
|
||||
- pip install -U -r requirements.txt
|
||||
|
||||
script:
|
||||
- scripts/test
|
||||
|
||||
after_script:
|
||||
- pip install codecov
|
||||
- codecov
|
||||
@ -2,9 +2,9 @@ from .config import PoolLimits, SSLConfig, TimeoutConfig
|
||||
from .connectionpool import ConnectionPool
|
||||
from .datastructures import URL, Origin, Request, Response
|
||||
from .exceptions import (
|
||||
BadResponse,
|
||||
ConnectTimeout,
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ReadTimeout,
|
||||
ResponseClosed,
|
||||
StreamConsumed,
|
||||
@ -13,4 +13,4 @@ from .exceptions import (
|
||||
from .http11 import HTTP11Connection
|
||||
from .sync import SyncClient, SyncConnectionPool
|
||||
|
||||
__version__ = "0.2.0"
|
||||
__version__ = "0.2.1"
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
if hasattr(asyncio, "run"):
|
||||
asyncio_run = asyncio.run
|
||||
|
||||
else: # pragma: nocover
|
||||
|
||||
def asyncio_run(main, *, debug=False): # type: ignore
|
||||
if asyncio._get_running_loop() is not None:
|
||||
raise RuntimeError(
|
||||
"asyncio.run() cannot be called from a running event loop"
|
||||
)
|
||||
|
||||
if not asyncio.iscoroutine(main):
|
||||
raise ValueError("a coroutine was expected, got {!r}".format(main))
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.set_debug(debug)
|
||||
return loop.run_until_complete(main)
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
def _cancel_all_tasks(loop): # type: ignore
|
||||
to_cancel = asyncio.all_tasks(loop)
|
||||
if not to_cancel:
|
||||
return
|
||||
|
||||
for task in to_cancel:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(
|
||||
tasks.gather(*to_cancel, loop=loop, return_exceptions=True)
|
||||
)
|
||||
|
||||
for task in to_cancel:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "unhandled exception during asyncio.run() shutdown",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
@ -10,9 +10,9 @@ from .config import (
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from .http11 import HTTP11Connection
|
||||
from .datastructures import Client, Origin, Request, Response
|
||||
from .exceptions import PoolTimeout
|
||||
from .http11 import HTTP11Connection
|
||||
|
||||
|
||||
class ConnectionPool(Client):
|
||||
@ -102,6 +102,12 @@ class ConnectionPool(Client):
|
||||
|
||||
async def close(self) -> None:
|
||||
self.is_closed = True
|
||||
all_connections = []
|
||||
for connections in self._keepalive_connections.values():
|
||||
all_connections.extend(list(connections))
|
||||
self._keepalive_connections.clear()
|
||||
for connection in all_connections:
|
||||
await connection.close()
|
||||
|
||||
|
||||
class ConnectionSemaphore:
|
||||
|
||||
@ -22,9 +22,9 @@ class PoolTimeout(Timeout):
|
||||
"""
|
||||
|
||||
|
||||
class BadResponse(Exception):
|
||||
class ProtocolError(Exception):
|
||||
"""
|
||||
A malformed HTTP response.
|
||||
Malformed HTTP.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .compat import asyncio_run
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .connectionpool import ConnectionPool
|
||||
from .datastructures import URL, Client, Response
|
||||
|
||||
|
||||
class SyncResponse:
|
||||
def __init__(self, response: Response):
|
||||
def __init__(self, response: Response, loop: asyncio.AbstractEventLoop):
|
||||
self._response = response
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
@ -28,23 +29,24 @@ class SyncResponse:
|
||||
return self._response.body
|
||||
|
||||
def read(self) -> bytes:
|
||||
return asyncio_run(self._response.read())
|
||||
return self._loop.run_until_complete(self._response.read())
|
||||
|
||||
def stream(self) -> typing.Iterator[bytes]:
|
||||
inner = self._response.stream()
|
||||
while True:
|
||||
try:
|
||||
yield asyncio_run(inner.__anext__())
|
||||
yield self._loop.run_until_complete(inner.__anext__())
|
||||
except StopAsyncIteration as exc:
|
||||
break
|
||||
|
||||
def close(self) -> None:
|
||||
return asyncio_run(self._response.close())
|
||||
return self._loop.run_until_complete(self._response.close())
|
||||
|
||||
|
||||
class SyncClient:
|
||||
def __init__(self, client: Client):
|
||||
self._client = client
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def request(
|
||||
self,
|
||||
@ -57,7 +59,7 @@ class SyncClient:
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
stream: bool = False,
|
||||
) -> SyncResponse:
|
||||
response = asyncio_run(
|
||||
response = self._loop.run_until_complete(
|
||||
self._client.request(
|
||||
method,
|
||||
url,
|
||||
@ -68,10 +70,10 @@ class SyncClient:
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
return SyncResponse(response)
|
||||
return SyncResponse(response, self._loop)
|
||||
|
||||
def close(self) -> None:
|
||||
asyncio_run(self._client.close())
|
||||
self._loop.run_until_complete(self._client.close())
|
||||
|
||||
def __enter__(self) -> "SyncClient":
|
||||
return self
|
||||
|
||||
@ -8,7 +8,6 @@ brotlipy
|
||||
# Testing
|
||||
autoflake
|
||||
black
|
||||
codecov
|
||||
isort
|
||||
mypy
|
||||
pytest
|
||||
|
||||
Loading…
Reference in New Issue
Block a user