Ensured the explicit closing of async generators
# Conflicts: # CHANGELOG.md
This commit is contained in:
parent
4fb9528c2f
commit
3aa2ef51be
@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||
|
||||
## [UNRELEASED]
|
||||
|
||||
### Fixed
|
||||
|
||||
* Explicitly close all async generators to ensure predictable behavior
|
||||
|
||||
### Removed
|
||||
|
||||
* Drop support for Python 3.8
|
||||
|
||||
@ -6,6 +6,7 @@ import logging
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from types import TracebackType
|
||||
|
||||
@ -46,7 +47,7 @@ from ._types import (
|
||||
TimeoutTypes,
|
||||
)
|
||||
from ._urls import URL, QueryParams
|
||||
from ._utils import URLPattern, get_environment_proxies
|
||||
from ._utils import URLPattern, get_environment_proxies, safe_async_iterate
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl # pragma: no cover
|
||||
@ -172,9 +173,10 @@ class BoundAsyncStream(AsyncByteStream):
|
||||
self._response = response
|
||||
self._start = start
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
async for chunk in self._stream:
|
||||
yield chunk
|
||||
async def __aiter__(self) -> AsyncGenerator[bytes]:
|
||||
async with safe_async_iterate(self._stream) as iterator:
|
||||
async for chunk in iterator:
|
||||
yield chunk
|
||||
|
||||
async def aclose(self) -> None:
|
||||
elapsed = time.perf_counter() - self._start
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from json import dumps as json_dumps
|
||||
from typing import (
|
||||
Any,
|
||||
@ -10,6 +11,7 @@ from typing import (
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
)
|
||||
from urllib.parse import urlencode
|
||||
|
||||
@ -23,7 +25,7 @@ from ._types import (
|
||||
ResponseContent,
|
||||
SyncByteStream,
|
||||
)
|
||||
from ._utils import peek_filelike_length, primitive_value_to_str
|
||||
from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate
|
||||
|
||||
__all__ = ["ByteStream"]
|
||||
|
||||
@ -35,7 +37,7 @@ class ByteStream(AsyncByteStream, SyncByteStream):
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
yield self._stream
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
async def __aiter__(self) -> AsyncGenerator[bytes]:
|
||||
yield self._stream
|
||||
|
||||
|
||||
@ -85,8 +87,9 @@ class AsyncIteratorByteStream(AsyncByteStream):
|
||||
chunk = await self._stream.aread(self.CHUNK_SIZE)
|
||||
else:
|
||||
# Otherwise iterate.
|
||||
async for part in self._stream:
|
||||
yield part
|
||||
async with safe_async_iterate(self._stream) as iterator:
|
||||
async for part in iterator:
|
||||
yield part
|
||||
|
||||
|
||||
class UnattachedStream(AsyncByteStream, SyncByteStream):
|
||||
@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream):
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
raise StreamClosed()
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
def __aiter__(self) -> NoReturn:
|
||||
raise StreamClosed()
|
||||
yield b"" # pragma: no cover
|
||||
|
||||
|
||||
def encode_content(
|
||||
|
||||
@ -7,7 +7,7 @@ import json as jsonlib
|
||||
import re
|
||||
import typing
|
||||
import urllib.request
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from http.cookiejar import Cookie, CookieJar
|
||||
|
||||
from ._content import ByteStream, UnattachedStream, encode_request, encode_response
|
||||
@ -46,7 +46,7 @@ from ._types import (
|
||||
SyncByteStream,
|
||||
)
|
||||
from ._urls import URL
|
||||
from ._utils import to_bytes_or_str, to_str
|
||||
from ._utils import safe_async_iterate, to_bytes_or_str, to_str
|
||||
|
||||
__all__ = ["Cookies", "Headers", "Request", "Response"]
|
||||
|
||||
@ -979,9 +979,7 @@ class Response:
|
||||
self._content = b"".join([part async for part in self.aiter_bytes()])
|
||||
return self._content
|
||||
|
||||
async def aiter_bytes(
|
||||
self, chunk_size: int | None = None
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
|
||||
@ -994,19 +992,19 @@ class Response:
|
||||
decoder = self._get_content_decoder()
|
||||
chunker = ByteChunker(chunk_size=chunk_size)
|
||||
with request_context(request=self._request):
|
||||
async for raw_bytes in self.aiter_raw():
|
||||
decoded = decoder.decode(raw_bytes)
|
||||
for chunk in chunker.decode(decoded):
|
||||
yield chunk
|
||||
async with safe_async_iterate(self.aiter_raw()) as iterator:
|
||||
async for raw_bytes in iterator:
|
||||
decoded = decoder.decode(raw_bytes)
|
||||
for chunk in chunker.decode(decoded):
|
||||
yield chunk
|
||||
|
||||
decoded = decoder.flush()
|
||||
for chunk in chunker.decode(decoded):
|
||||
yield chunk # pragma: no cover
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
async def aiter_text(
|
||||
self, chunk_size: int | None = None
|
||||
) -> typing.AsyncIterator[str]:
|
||||
async def aiter_text(self, chunk_size: int | None = None) -> AsyncGenerator[str]:
|
||||
"""
|
||||
A str-iterator over the decoded response content
|
||||
that handles both gzip, deflate, etc but also detects the content's
|
||||
@ -1015,28 +1013,28 @@ class Response:
|
||||
decoder = TextDecoder(encoding=self.encoding or "utf-8")
|
||||
chunker = TextChunker(chunk_size=chunk_size)
|
||||
with request_context(request=self._request):
|
||||
async for byte_content in self.aiter_bytes():
|
||||
text_content = decoder.decode(byte_content)
|
||||
for chunk in chunker.decode(text_content):
|
||||
yield chunk
|
||||
async with safe_async_iterate(self.aiter_bytes()) as iterator:
|
||||
async for byte_content in iterator:
|
||||
text_content = decoder.decode(byte_content)
|
||||
for chunk in chunker.decode(text_content):
|
||||
yield chunk
|
||||
text_content = decoder.flush()
|
||||
for chunk in chunker.decode(text_content):
|
||||
yield chunk # pragma: no cover
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
async def aiter_lines(self) -> typing.AsyncIterator[str]:
|
||||
async def aiter_lines(self) -> AsyncGenerator[str]:
|
||||
decoder = LineDecoder()
|
||||
with request_context(request=self._request):
|
||||
async for text in self.aiter_text():
|
||||
for line in decoder.decode(text):
|
||||
yield line
|
||||
async with safe_async_iterate(self.aiter_text()) as iterator:
|
||||
async for text in iterator:
|
||||
for line in decoder.decode(text):
|
||||
yield line
|
||||
for line in decoder.flush():
|
||||
yield line
|
||||
|
||||
async def aiter_raw(
|
||||
self, chunk_size: int | None = None
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
@ -1052,10 +1050,11 @@ class Response:
|
||||
chunker = ByteChunker(chunk_size=chunk_size)
|
||||
|
||||
with request_context(request=self._request):
|
||||
async for raw_stream_bytes in self.stream:
|
||||
self._num_bytes_downloaded += len(raw_stream_bytes)
|
||||
for chunk in chunker.decode(raw_stream_bytes):
|
||||
yield chunk
|
||||
async with safe_async_iterate(self.stream) as iterator:
|
||||
async for raw_stream_bytes in iterator:
|
||||
self._num_bytes_downloaded += len(raw_stream_bytes)
|
||||
for chunk in chunker.decode(raw_stream_bytes):
|
||||
yield chunk
|
||||
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
@ -5,6 +5,7 @@ import mimetypes
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
|
||||
from ._types import (
|
||||
@ -295,6 +296,6 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
for chunk in self.iter_chunks():
|
||||
yield chunk
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
async def __aiter__(self) -> AsyncGenerator[bytes]:
|
||||
for chunk in self.iter_chunks():
|
||||
yield chunk
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from .._models import Request, Response
|
||||
from .._types import AsyncByteStream
|
||||
@ -56,7 +57,7 @@ class ASGIResponseStream(AsyncByteStream):
|
||||
def __init__(self, body: list[bytes]) -> None:
|
||||
self._body = body
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
async def __aiter__(self) -> AsyncGenerator[bytes]:
|
||||
yield b"".join(self._body)
|
||||
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator
|
||||
from types import TracebackType
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
@ -55,6 +56,7 @@ from .._exceptions import (
|
||||
from .._models import Request, Response
|
||||
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
|
||||
from .._urls import URL
|
||||
from .._utils import safe_async_iterate
|
||||
from .base import AsyncBaseTransport, BaseTransport
|
||||
|
||||
T = typing.TypeVar("T", bound="HTTPTransport")
|
||||
@ -266,10 +268,11 @@ class AsyncResponseStream(AsyncByteStream):
|
||||
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
|
||||
self._httpcore_stream = httpcore_stream
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
async def __aiter__(self) -> AsyncGenerator[bytes]:
|
||||
with map_httpcore_exceptions():
|
||||
async for part in self._httpcore_stream:
|
||||
yield part
|
||||
async with safe_async_iterate(self._httpcore_stream) as iterator:
|
||||
async for part in iterator:
|
||||
yield part
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if hasattr(self._httpcore_stream, "aclose"):
|
||||
|
||||
@ -94,7 +94,6 @@ class SyncByteStream:
|
||||
raise NotImplementedError(
|
||||
"The '__iter__' method must be implemented."
|
||||
) # pragma: no cover
|
||||
yield b"" # pragma: no cover
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
@ -104,11 +103,10 @@ class SyncByteStream:
|
||||
|
||||
|
||||
class AsyncByteStream:
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
raise NotImplementedError(
|
||||
"The '__aiter__' method must be implemented."
|
||||
) # pragma: no cover
|
||||
yield b"" # pragma: no cover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass
|
||||
|
||||
@ -4,6 +4,9 @@ import ipaddress
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import isasyncgen
|
||||
from urllib.request import getproxies
|
||||
|
||||
from ._types import PrimitiveData
|
||||
@ -11,6 +14,8 @@ from ._types import PrimitiveData
|
||||
if typing.TYPE_CHECKING: # pragma: no cover
|
||||
from ._urls import URL
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
def primitive_value_to_str(value: PrimitiveData) -> str:
|
||||
"""
|
||||
@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool:
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def safe_async_iterate(
|
||||
iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], /
|
||||
) -> AsyncGenerator[AsyncIterator[T]]:
|
||||
iterator = (
|
||||
iterable_or_iterator
|
||||
if isinstance(iterable_or_iterator, AsyncIterator)
|
||||
else iterable_or_iterator.__aiter__()
|
||||
)
|
||||
try:
|
||||
yield iterator
|
||||
finally:
|
||||
if isasyncgen(iterator):
|
||||
await iterator.aclose()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import io
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
||||
@ -64,20 +65,10 @@ async def test_bytesio_content():
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_bytesio_content():
|
||||
class AsyncBytesIO:
|
||||
def __init__(self, content: bytes) -> None:
|
||||
self._idx = 0
|
||||
self._content = content
|
||||
async def fixed_stream(content: bytes) -> AsyncGenerator[bytes]:
|
||||
yield content
|
||||
|
||||
async def aread(self, chunk_size: int) -> bytes:
|
||||
chunk = self._content[self._idx : self._idx + chunk_size]
|
||||
self._idx = self._idx + chunk_size
|
||||
return chunk
|
||||
|
||||
async def __aiter__(self):
|
||||
yield self._content # pragma: no cover
|
||||
|
||||
request = httpx.Request(method, url, content=AsyncBytesIO(b"Hello, world!"))
|
||||
request = httpx.Request(method, url, content=fixed_stream(b"Hello, world!"))
|
||||
assert not isinstance(request.stream, typing.Iterable)
|
||||
assert isinstance(request.stream, typing.AsyncIterable)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user