Ensured the explicit closing of async generators

# Conflicts:
#	CHANGELOG.md
This commit is contained in:
Alex Grönholm 2025-06-26 17:53:03 +03:00
parent 4fb9528c2f
commit 3aa2ef51be
10 changed files with 80 additions and 58 deletions

View File

@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [UNRELEASED]
### Fixed
* Explicitly close all async generators to ensure predictable behavior
### Removed
* Drop support for Python 3.8

View File

@ -6,6 +6,7 @@ import logging
import time
import typing
import warnings
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from types import TracebackType
@ -46,7 +47,7 @@ from ._types import (
TimeoutTypes,
)
from ._urls import URL, QueryParams
from ._utils import URLPattern, get_environment_proxies
from ._utils import URLPattern, get_environment_proxies, safe_async_iterate
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
@ -172,9 +173,10 @@ class BoundAsyncStream(AsyncByteStream):
self._response = response
self._start = start
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for chunk in self._stream:
yield chunk
async def __aiter__(self) -> AsyncGenerator[bytes]:
async with safe_async_iterate(self._stream) as iterator:
async for chunk in iterator:
yield chunk
async def aclose(self) -> None:
elapsed = time.perf_counter() - self._start

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import inspect
import warnings
from collections.abc import AsyncGenerator
from json import dumps as json_dumps
from typing import (
Any,
@ -10,6 +11,7 @@ from typing import (
Iterable,
Iterator,
Mapping,
NoReturn,
)
from urllib.parse import urlencode
@ -23,7 +25,7 @@ from ._types import (
ResponseContent,
SyncByteStream,
)
from ._utils import peek_filelike_length, primitive_value_to_str
from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate
__all__ = ["ByteStream"]
@ -35,7 +37,7 @@ class ByteStream(AsyncByteStream, SyncByteStream):
def __iter__(self) -> Iterator[bytes]:
yield self._stream
async def __aiter__(self) -> AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
yield self._stream
@ -85,8 +87,9 @@ class AsyncIteratorByteStream(AsyncByteStream):
chunk = await self._stream.aread(self.CHUNK_SIZE)
else:
# Otherwise iterate.
async for part in self._stream:
yield part
async with safe_async_iterate(self._stream) as iterator:
async for part in iterator:
yield part
class UnattachedStream(AsyncByteStream, SyncByteStream):
@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream):
def __iter__(self) -> Iterator[bytes]:
raise StreamClosed()
async def __aiter__(self) -> AsyncIterator[bytes]:
def __aiter__(self) -> NoReturn:
raise StreamClosed()
yield b"" # pragma: no cover
def encode_content(

View File

@ -7,7 +7,7 @@ import json as jsonlib
import re
import typing
import urllib.request
from collections.abc import Mapping
from collections.abc import AsyncGenerator, Mapping
from http.cookiejar import Cookie, CookieJar
from ._content import ByteStream, UnattachedStream, encode_request, encode_response
@ -46,7 +46,7 @@ from ._types import (
SyncByteStream,
)
from ._urls import URL
from ._utils import to_bytes_or_str, to_str
from ._utils import safe_async_iterate, to_bytes_or_str, to_str
__all__ = ["Cookies", "Headers", "Request", "Response"]
@ -979,9 +979,7 @@ class Response:
self._content = b"".join([part async for part in self.aiter_bytes()])
return self._content
async def aiter_bytes(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
@ -994,19 +992,19 @@ class Response:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
async with safe_async_iterate(self.aiter_raw()) as iterator:
async for raw_bytes in iterator:
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
async def aiter_text(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[str]:
async def aiter_text(self, chunk_size: int | None = None) -> AsyncGenerator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
@ -1015,28 +1013,28 @@ class Response:
decoder = TextDecoder(encoding=self.encoding or "utf-8")
chunker = TextChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
async with safe_async_iterate(self.aiter_bytes()) as iterator:
async for byte_content in iterator:
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
async def aiter_lines(self) -> typing.AsyncIterator[str]:
async def aiter_lines(self) -> AsyncGenerator[str]:
decoder = LineDecoder()
with request_context(request=self._request):
async for text in self.aiter_text():
for line in decoder.decode(text):
yield line
async with safe_async_iterate(self.aiter_text()) as iterator:
async for text in iterator:
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
async def aiter_raw(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
"""
A byte-iterator over the raw response content.
"""
@ -1052,10 +1050,11 @@ class Response:
chunker = ByteChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk
async with safe_async_iterate(self.stream) as iterator:
async for raw_stream_bytes in iterator:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk
for chunk in chunker.flush():
yield chunk

View File

@ -5,6 +5,7 @@ import mimetypes
import os
import re
import typing
from collections.abc import AsyncGenerator
from pathlib import Path
from ._types import (
@ -295,6 +296,6 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
for chunk in self.iter_chunks():
yield chunk
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
for chunk in self.iter_chunks():
yield chunk

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import typing
from collections.abc import AsyncGenerator
from .._models import Request, Response
from .._types import AsyncByteStream
@ -56,7 +57,7 @@ class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: list[bytes]) -> None:
self._body = body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
yield b"".join(self._body)

View File

@ -28,6 +28,7 @@ from __future__ import annotations
import contextlib
import typing
from collections.abc import AsyncGenerator
from types import TracebackType
if typing.TYPE_CHECKING:
@ -55,6 +56,7 @@ from .._exceptions import (
from .._models import Request, Response
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
from .._urls import URL
from .._utils import safe_async_iterate
from .base import AsyncBaseTransport, BaseTransport
T = typing.TypeVar("T", bound="HTTPTransport")
@ -266,10 +268,11 @@ class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
with map_httpcore_exceptions():
async for part in self._httpcore_stream:
yield part
async with safe_async_iterate(self._httpcore_stream) as iterator:
async for part in iterator:
yield part
async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):

View File

@ -94,7 +94,6 @@ class SyncByteStream:
raise NotImplementedError(
"The '__iter__' method must be implemented."
) # pragma: no cover
yield b"" # pragma: no cover
def close(self) -> None:
"""
@ -104,11 +103,10 @@ class SyncByteStream:
class AsyncByteStream:
async def __aiter__(self) -> AsyncIterator[bytes]:
def __aiter__(self) -> AsyncIterator[bytes]:
raise NotImplementedError(
"The '__aiter__' method must be implemented."
) # pragma: no cover
yield b"" # pragma: no cover
async def aclose(self) -> None:
pass

View File

@ -4,6 +4,9 @@ import ipaddress
import os
import re
import typing
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from inspect import isasyncgen
from urllib.request import getproxies
from ._types import PrimitiveData
@ -11,6 +14,8 @@ from ._types import PrimitiveData
if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL
T = typing.TypeVar("T")
def primitive_value_to_str(value: PrimitiveData) -> str:
"""
@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool:
except Exception:
return False
return True
@asynccontextmanager
async def safe_async_iterate(
iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], /
) -> AsyncGenerator[AsyncIterator[T]]:
iterator = (
iterable_or_iterator
if isinstance(iterable_or_iterator, AsyncIterator)
else iterable_or_iterator.__aiter__()
)
try:
yield iterator
finally:
if isasyncgen(iterator):
await iterator.aclose()

View File

@ -1,5 +1,6 @@
import io
import typing
from collections.abc import AsyncGenerator
import pytest
@ -64,20 +65,10 @@ async def test_bytesio_content():
@pytest.mark.anyio
async def test_async_bytesio_content():
class AsyncBytesIO:
def __init__(self, content: bytes) -> None:
self._idx = 0
self._content = content
async def fixed_stream(content: bytes) -> AsyncGenerator[bytes]:
yield content
async def aread(self, chunk_size: int) -> bytes:
chunk = self._content[self._idx : self._idx + chunk_size]
self._idx = self._idx + chunk_size
return chunk
async def __aiter__(self):
yield self._content # pragma: no cover
request = httpx.Request(method, url, content=AsyncBytesIO(b"Hello, world!"))
request = httpx.Request(method, url, content=fixed_stream(b"Hello, world!"))
assert not isinstance(request.stream, typing.Iterable)
assert isinstance(request.stream, typing.AsyncIterable)