From 3aa2ef51be67acfa65ee2c8469a7fbb7d89350c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 26 Jun 2025 17:53:03 +0300 Subject: [PATCH] Ensured the explicit closing of async generators --- CHANGELOG.md | 4 +++ httpx/_client.py | 10 ++++--- httpx/_content.py | 14 ++++++---- httpx/_models.py | 53 ++++++++++++++++++------------------ httpx/_multipart.py | 3 +- httpx/_transports/asgi.py | 3 +- httpx/_transports/default.py | 9 ++++-- httpx/_types.py | 4 +-- httpx/_utils.py | 21 ++++++++++++++ tests/test_content.py | 17 +++--------- 10 files changed, 80 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13bbfcdb..865023a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [UNRELEASED] +### Fixed + +* Explicitly close all async generators to ensure predictable behavior + ### Removed * Drop support for Python 3.8 diff --git a/httpx/_client.py b/httpx/_client.py index 13cd9336..df46a59c 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -6,6 +6,7 @@ import logging import time import typing import warnings +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager, contextmanager from types import TracebackType @@ -46,7 +47,7 @@ from ._types import ( TimeoutTypes, ) from ._urls import URL, QueryParams -from ._utils import URLPattern, get_environment_proxies +from ._utils import URLPattern, get_environment_proxies, safe_async_iterate if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -172,9 +173,10 @@ class BoundAsyncStream(AsyncByteStream): self._response = response self._start = start - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - async for chunk in self._stream: - yield chunk + async def __aiter__(self) -> AsyncGenerator[bytes]: + async with safe_async_iterate(self._stream) as iterator: + async for chunk in iterator: + yield chunk async def
aclose(self) -> None: elapsed = time.perf_counter() - self._start diff --git a/httpx/_content.py b/httpx/_content.py index 6f479a08..f18b371f 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -2,6 +2,7 @@ from __future__ import annotations import inspect import warnings +from collections.abc import AsyncGenerator from json import dumps as json_dumps from typing import ( Any, @@ -10,6 +11,7 @@ from typing import ( Iterable, Iterator, Mapping, + NoReturn, ) from urllib.parse import urlencode @@ -23,7 +25,7 @@ from ._types import ( ResponseContent, SyncByteStream, ) -from ._utils import peek_filelike_length, primitive_value_to_str +from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate __all__ = ["ByteStream"] @@ -35,7 +37,7 @@ class ByteStream(AsyncByteStream, SyncByteStream): def __iter__(self) -> Iterator[bytes]: yield self._stream - async def __aiter__(self) -> AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: yield self._stream @@ -85,8 +87,9 @@ class AsyncIteratorByteStream(AsyncByteStream): chunk = await self._stream.aread(self.CHUNK_SIZE) else: # Otherwise iterate. 
- async for part in self._stream: - yield part + async with safe_async_iterate(self._stream) as iterator: + async for part in iterator: + yield part class UnattachedStream(AsyncByteStream, SyncByteStream): @@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream): def __iter__(self) -> Iterator[bytes]: raise StreamClosed() - async def __aiter__(self) -> AsyncIterator[bytes]: + def __aiter__(self) -> NoReturn: raise StreamClosed() - yield b"" # pragma: no cover def encode_content( diff --git a/httpx/_models.py b/httpx/_models.py index 2cc86321..a9c04dbd 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -7,7 +7,7 @@ import json as jsonlib import re import typing import urllib.request -from collections.abc import Mapping +from collections.abc import AsyncGenerator, Mapping from http.cookiejar import Cookie, CookieJar from ._content import ByteStream, UnattachedStream, encode_request, encode_response @@ -46,7 +46,7 @@ from ._types import ( SyncByteStream, ) from ._urls import URL -from ._utils import to_bytes_or_str, to_str +from ._utils import safe_async_iterate, to_bytes_or_str, to_str __all__ = ["Cookies", "Headers", "Request", "Response"] @@ -979,9 +979,7 @@ class Response: self._content = b"".join([part async for part in self.aiter_bytes()]) return self._content - async def aiter_bytes( - self, chunk_size: int | None = None - ) -> typing.AsyncIterator[bytes]: + async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]: """ A byte-iterator over the decoded response content. This allows us to handle gzip, deflate, brotli, and zstd encoded responses. 
@@ -994,19 +992,19 @@ class Response: decoder = self._get_content_decoder() chunker = ByteChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for raw_bytes in self.aiter_raw(): - decoded = decoder.decode(raw_bytes) - for chunk in chunker.decode(decoded): - yield chunk + async with safe_async_iterate(self.aiter_raw()) as iterator: + async for raw_bytes in iterator: + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() for chunk in chunker.decode(decoded): yield chunk # pragma: no cover for chunk in chunker.flush(): yield chunk - async def aiter_text( - self, chunk_size: int | None = None - ) -> typing.AsyncIterator[str]: + async def aiter_text(self, chunk_size: int | None = None) -> AsyncGenerator[str]: """ A str-iterator over the decoded response content that handles both gzip, deflate, etc but also detects the content's @@ -1015,28 +1013,28 @@ class Response: decoder = TextDecoder(encoding=self.encoding or "utf-8") chunker = TextChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for byte_content in self.aiter_bytes(): - text_content = decoder.decode(byte_content) - for chunk in chunker.decode(text_content): - yield chunk + async with safe_async_iterate(self.aiter_bytes()) as iterator: + async for byte_content in iterator: + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk text_content = decoder.flush() for chunk in chunker.decode(text_content): yield chunk # pragma: no cover for chunk in chunker.flush(): yield chunk - async def aiter_lines(self) -> typing.AsyncIterator[str]: + async def aiter_lines(self) -> AsyncGenerator[str]: decoder = LineDecoder() with request_context(request=self._request): - async for text in self.aiter_text(): - for line in decoder.decode(text): - yield line + async with safe_async_iterate(self.aiter_text()) as iterator: + async for text in iterator: + for 
line in decoder.decode(text): + yield line for line in decoder.flush(): yield line - async def aiter_raw( - self, chunk_size: int | None = None - ) -> typing.AsyncIterator[bytes]: + async def aiter_raw(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]: """ A byte-iterator over the raw response content. """ @@ -1052,10 +1050,11 @@ class Response: chunker = ByteChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for raw_stream_bytes in self.stream: - self._num_bytes_downloaded += len(raw_stream_bytes) - for chunk in chunker.decode(raw_stream_bytes): - yield chunk + async with safe_async_iterate(self.stream) as iterator: + async for raw_stream_bytes in iterator: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk for chunk in chunker.flush(): yield chunk diff --git a/httpx/_multipart.py b/httpx/_multipart.py index b4761af9..1e5d522b 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -5,6 +5,7 @@ import mimetypes import os import re import typing +from collections.abc import AsyncGenerator from pathlib import Path from ._types import ( @@ -295,6 +296,6 @@ class MultipartStream(SyncByteStream, AsyncByteStream): for chunk in self.iter_chunks(): yield chunk - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: for chunk in self.iter_chunks(): yield chunk diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index 2bc4efae..910e5644 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from collections.abc import AsyncGenerator from .._models import Request, Response from .._types import AsyncByteStream @@ -56,7 +57,7 @@ class ASGIResponseStream(AsyncByteStream): def __init__(self, body: list[bytes]) -> None: self._body = body - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def 
__aiter__(self) -> AsyncGenerator[bytes]: yield b"".join(self._body) diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index fc8c7097..6242a5cf 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -28,6 +28,7 @@ from __future__ import annotations import contextlib import typing +from collections.abc import AsyncGenerator from types import TracebackType if typing.TYPE_CHECKING: @@ -55,6 +56,7 @@ from .._exceptions import ( from .._models import Request, Response from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream from .._urls import URL +from .._utils import safe_async_iterate from .base import AsyncBaseTransport, BaseTransport T = typing.TypeVar("T", bound="HTTPTransport") @@ -266,10 +268,11 @@ class AsyncResponseStream(AsyncByteStream): def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None: self._httpcore_stream = httpcore_stream - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: with map_httpcore_exceptions(): - async for part in self._httpcore_stream: - yield part + async with safe_async_iterate(self._httpcore_stream) as iterator: + async for part in iterator: + yield part async def aclose(self) -> None: if hasattr(self._httpcore_stream, "aclose"): diff --git a/httpx/_types.py b/httpx/_types.py index 704dfdff..44dac6cf 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -94,7 +94,6 @@ class SyncByteStream: raise NotImplementedError( "The '__iter__' method must be implemented." ) # pragma: no cover - yield b"" # pragma: no cover def close(self) -> None: """ @@ -104,11 +103,10 @@ class SyncByteStream: class AsyncByteStream: - async def __aiter__(self) -> AsyncIterator[bytes]: + def __aiter__(self) -> AsyncIterator[bytes]: raise NotImplementedError( "The '__aiter__' method must be implemented." 
) # pragma: no cover - yield b"" # pragma: no cover async def aclose(self) -> None: pass diff --git a/httpx/_utils.py b/httpx/_utils.py index 7fe827da..b6acaba1 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -4,6 +4,9 @@ import ipaddress import os import re import typing +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from inspect import isasyncgen from urllib.request import getproxies from ._types import PrimitiveData @@ -11,6 +14,8 @@ from ._types import PrimitiveData if typing.TYPE_CHECKING: # pragma: no cover from ._urls import URL +T = typing.TypeVar("T") + def primitive_value_to_str(value: PrimitiveData) -> str: """ @@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool: except Exception: return False return True + + +@asynccontextmanager +async def safe_async_iterate( + iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], / +) -> AsyncGenerator[AsyncIterator[T]]: + iterator = ( + iterable_or_iterator + if isinstance(iterable_or_iterator, AsyncIterator) + else iterable_or_iterator.__aiter__() + ) + try: + yield iterator + finally: + if isasyncgen(iterator): + await iterator.aclose() diff --git a/tests/test_content.py b/tests/test_content.py index f63ec18a..b43b65a6 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -1,5 +1,6 @@ import io import typing +from collections.abc import AsyncGenerator import pytest @@ -64,20 +65,10 @@ async def test_bytesio_content(): @pytest.mark.anyio async def test_async_bytesio_content(): - class AsyncBytesIO: - def __init__(self, content: bytes) -> None: - self._idx = 0 - self._content = content + async def fixed_stream(content: bytes) -> AsyncGenerator[bytes]: + yield content - async def aread(self, chunk_size: int) -> bytes: - chunk = self._content[self._idx : self._idx + chunk_size] - self._idx = self._idx + chunk_size - return chunk - - async def __aiter__(self): - yield self._content # pragma: no cover - - request = 
httpx.Request(method, url, content=AsyncBytesIO(b"Hello, world!")) + request = httpx.Request(method, url, content=fixed_stream(b"Hello, world!")) assert not isinstance(request.stream, typing.Iterable) assert isinstance(request.stream, typing.AsyncIterable)