Drop 'Response(on_close=...)' from API (#1572)

This commit is contained in:
Tom Christie 2021-04-16 10:03:37 +01:00 committed by GitHub
parent 4870cb5adf
commit 073a3284ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 21 deletions

View File

@ -86,6 +86,52 @@ class ClientState(enum.Enum):
CLOSED = 3
class BoundSyncStream(SyncByteStream):
"""
A byte stream that is bound to a given response instance, and that
ensures the `response.elapsed` is set once the response is closed.
"""
def __init__(
self, stream: SyncByteStream, response: Response, timer: Timer
) -> None:
self._stream = stream
self._response = response
self._timer = timer
def __iter__(self) -> typing.Iterator[bytes]:
for chunk in self._stream:
yield chunk
def close(self) -> None:
seconds = self._timer.sync_elapsed()
self._response.elapsed = datetime.timedelta(seconds=seconds)
self._stream.close()
class BoundAsyncStream(AsyncByteStream):
"""
An async byte stream that is bound to a given response instance, and that
ensures the `response.elapsed` is set once the response is closed.
"""
def __init__(
self, stream: AsyncByteStream, response: Response, timer: Timer
) -> None:
self._stream = stream
self._response = response
self._timer = timer
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for chunk in self._stream:
yield chunk
async def aclose(self) -> None:
seconds = await self._timer.async_elapsed()
self._response.elapsed = datetime.timedelta(seconds=seconds)
await self._stream.aclose()
class BaseClient:
def __init__(
self,
@ -874,28 +920,29 @@ class Client(BaseClient):
timer = Timer()
timer.sync_start()
if not isinstance(request.stream, SyncByteStream):
raise RuntimeError(
"Attempted to send an async request with a sync Client instance."
)
with request_context(request=request):
(status_code, headers, stream, extensions) = transport.handle_request(
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
stream=request.stream, # type: ignore
stream=request.stream,
extensions={"timeout": timeout.as_dict()},
)
def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
stream.close()
response = Response(
status_code,
headers=headers,
stream=stream,
extensions=extensions,
request=request,
on_close=on_close,
)
response.stream = BoundSyncStream(stream, response=response, timer=timer)
self.cookies.extract_cookies(response)
status = f"{response.status_code} {response.reason_phrase}"
@ -1512,6 +1559,11 @@ class AsyncClient(BaseClient):
timer = Timer()
await timer.async_start()
if not isinstance(request.stream, AsyncByteStream):
raise RuntimeError(
"Attempted to send an sync request with an AsyncClient instance."
)
with request_context(request=request):
(
status_code,
@ -1522,23 +1574,19 @@ class AsyncClient(BaseClient):
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
stream=request.stream, # type: ignore
stream=request.stream,
extensions={"timeout": timeout.as_dict()},
)
async def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
await stream.aclose()
response = Response(
status_code,
headers=headers,
stream=stream,
extensions=extensions,
request=request,
on_close=on_close,
)
response.stream = BoundAsyncStream(stream, response=response, timer=timer)
self.cookies.extract_cookies(response)
status = f"{response.status_code} {response.reason_phrase}"

View File

@ -908,7 +908,6 @@ class Response:
request: Request = None,
extensions: dict = None,
history: typing.List["Response"] = None,
on_close: typing.Callable = None,
):
self.status_code = status_code
self.headers = Headers(headers)
@ -923,7 +922,6 @@ class Response:
self.extensions = {} if extensions is None else extensions
self.history = [] if history is None else list(history)
self._on_close = on_close
self.is_closed = False
self.is_stream_consumed = False
@ -1245,11 +1243,13 @@ class Response:
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not isinstance(self.stream, SyncByteStream):
raise RuntimeError("Attempted to call an sync close on an async stream.")
if not self.is_closed:
self.is_closed = True
if self._on_close is not None:
with request_context(request=self._request):
self._on_close(self)
with request_context(request=self._request):
self.stream.close()
async def aread(self) -> bytes:
"""
@ -1341,11 +1341,13 @@ class Response:
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not isinstance(self.stream, AsyncByteStream):
raise RuntimeError("Attempted to call an async close on an sync stream.")
if not self.is_closed:
self.is_closed = True
if self._on_close is not None:
with request_context(request=self._request):
await self._on_close(self)
with request_context(request=self._request):
await self.stream.aclose()
class Cookies(MutableMapping):

View File

@ -94,10 +94,21 @@ async def test_stream_request(server):
yield b"world!"
async with httpx.AsyncClient() as client:
response = await client.request("POST", server.url, content=hello_world())
response = await client.post(server.url, content=hello_world())
assert response.status_code == 200
@pytest.mark.usefixtures("async_environment")
async def test_cannot_stream_sync_request(server):
def hello_world(): # pragma: nocover
yield b"Hello, "
yield b"world!"
async with httpx.AsyncClient() as client:
with pytest.raises(RuntimeError):
await client.post(server.url, content=hello_world())
@pytest.mark.usefixtures("async_environment")
async def test_raise_for_status(server):
async with httpx.AsyncClient() as client:

View File

@ -114,6 +114,16 @@ def test_raw_iterator(server):
assert body == b"Hello, world!"
def test_cannot_stream_async_request(server):
async def hello_world(): # pragma: nocover
yield b"Hello, "
yield b"world!"
with httpx.Client() as client:
with pytest.raises(RuntimeError):
client.post(server.url, content=hello_world())
def test_raise_for_status(server):
with httpx.Client() as client:
for status_code in (200, 400, 404, 500, 505):

View File

@ -382,6 +382,16 @@ def test_iter_raw_on_async():
[part for part in response.iter_raw()]
def test_close_on_async():
response = httpx.Response(
200,
content=async_streaming_body(),
)
with pytest.raises(RuntimeError):
response.close()
def test_iter_raw_increments_updates_counter():
response = httpx.Response(200, content=streaming_body())
@ -430,6 +440,17 @@ async def test_aiter_raw_on_sync():
[part async for part in response.aiter_raw()]
@pytest.mark.asyncio
async def test_aclose_on_sync():
response = httpx.Response(
200,
content=streaming_body(),
)
with pytest.raises(RuntimeError):
await response.aclose()
@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
response = httpx.Response(200, content=async_streaming_body())