This commit is contained in:
Vitali Tsimoshka 2026-02-24 07:02:22 +03:00 committed by GitHub
commit 680bc92474
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 284 additions and 6 deletions

28
httpx/_compat.py Normal file
View File

@ -0,0 +1,28 @@
import sys
if sys.version_info >= (3, 10):
from contextlib import aclosing
else:
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar
class _SupportsAclose(Protocol):
def aclose(self) -> Awaitable[object]: ...
_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)
@asynccontextmanager
async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[Any]:
try:
yield thing
finally:
await thing.aclose()
if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs
__all__ = ["aclosing", "TypeIs"]

View File

@ -22,8 +22,9 @@ from ._types import (
RequestFiles,
ResponseContent,
SyncByteStream,
is_async_readable_file,
)
from ._utils import peek_filelike_length, primitive_value_to_str
from ._utils import peek_filelike_length, primitive_value_to_str, to_bytes
__all__ = ["ByteStream"]
@ -83,6 +84,11 @@ class AsyncIteratorByteStream(AsyncByteStream):
while chunk:
yield chunk
chunk = await self._stream.aread(self.CHUNK_SIZE)
elif is_async_readable_file(self._stream):
chunk = await self._stream.read(self.CHUNK_SIZE)
while chunk:
yield to_bytes(chunk)
chunk = await self._stream.read(self.CHUNK_SIZE)
else:
# Otherwise iterate.
async for part in self._stream:
@ -127,7 +133,12 @@ def encode_content(
return headers, IteratorByteStream(content) # type: ignore
elif isinstance(content, AsyncIterable):
headers = {"Transfer-Encoding": "chunked"}
if is_async_readable_file(content) and (
content_length_or_none := peek_filelike_length(content)
):
headers = {"Content-Length": str(content_length_or_none)}
else:
headers = {"Transfer-Encoding": "chunked"}
return headers, AsyncIteratorByteStream(content)
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")

View File

@ -7,6 +7,7 @@ import re
import typing
from pathlib import Path
from ._compat import aclosing
from ._types import (
AsyncByteStream,
FileContent,
@ -14,6 +15,7 @@ from ._types import (
RequestData,
RequestFiles,
SyncByteStream,
is_async_readable_file,
)
from ._utils import (
peek_filelike_length,
@ -201,6 +203,11 @@ class FileField:
return self._headers
def render_data(self) -> typing.Iterator[bytes]:
if is_async_readable_file(self.file):
raise TypeError(
"Invalid type for file. AsyncReadableFile is not supported."
)
if isinstance(self.file, (str, bytes)):
yield to_bytes(self.file)
return
@ -216,10 +223,27 @@ class FileField:
yield to_bytes(chunk)
chunk = self.file.read(self.CHUNK_SIZE)
async def arender_data(self) -> typing.AsyncGenerator[bytes]:
if not is_async_readable_file(self.file):
for chunk in self.render_data():
yield chunk
return
await self.file.seek(0)
chunk = await self.file.read(self.CHUNK_SIZE)
while chunk:
yield to_bytes(chunk)
chunk = await self.file.read(self.CHUNK_SIZE)
def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
yield from self.render_data()
async def arender(self) -> typing.AsyncGenerator[bytes]:
yield self.render_headers()
async with aclosing(self.arender_data()) as data:
async for chunk in data:
yield chunk
class MultipartStream(SyncByteStream, AsyncByteStream):
"""
@ -262,6 +286,19 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary
async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]:
for field in self.fields:
yield b"--%s\r\n" % self.boundary
if isinstance(field, FileField):
async with aclosing(field.arender()) as data:
async for chunk in data:
yield chunk
else:
for chunk in field.render():
yield chunk
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary
def get_content_length(self) -> int | None:
"""
Return the length of the multipart encoded content, or `None` if
@ -296,5 +333,6 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
yield chunk
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
for chunk in self.iter_chunks():
yield chunk
async with aclosing(self.aiter_chunks()) as data:
async for chunk in data:
yield chunk

View File

@ -2,11 +2,13 @@
Type definitions for type checking purposes.
"""
import inspect
from http.cookiejar import CookieJar
from typing import (
IO,
TYPE_CHECKING,
Any,
AnyStr,
AsyncIterable,
AsyncIterator,
Callable,
@ -16,11 +18,14 @@ from typing import (
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Union,
)
from ._compat import TypeIs
if TYPE_CHECKING: # pragma: no cover
from ._auth import Auth # noqa: F401
from ._config import Proxy, Timeout # noqa: F401
@ -71,7 +76,18 @@ ResponseExtensions = Mapping[str, Any]
RequestData = Mapping[str, Any]
FileContent = Union[IO[bytes], bytes, str]
class AsyncReadableFile(Protocol):
async def __aiter__(self) -> AsyncIterator[AnyStr]: ...
async def read(self, size: int = -1) -> AnyStr: ...
def fileno(self) -> int: ...
async def seek(self, offset: int, whence: Optional[int] = ...) -> int: ...
FileContent = Union[IO[bytes], bytes, str, AsyncReadableFile]
FileTypes = Union[
# file (or bytes)
FileContent,
@ -112,3 +128,16 @@ class AsyncByteStream:
async def aclose(self) -> None:
pass
def is_async_readable_file(fp: Any) -> TypeIs[AsyncReadableFile]:
return (
isinstance(fp, AsyncIterable)
and hasattr(fp, "read")
and inspect.iscoroutinefunction(fp.read)
and hasattr(fp, "fileno")
and callable(fp.fileno)
and not inspect.iscoroutinefunction(fp.fileno)
and hasattr(fp, "seek")
and inspect.iscoroutinefunction(fp.seek)
)

View File

@ -128,5 +128,5 @@ markers = [
]
[tool.coverage.run]
omit = ["venv/*"]
omit = ["venv/*", "httpx/_compat.py"]
include = ["httpx/*", "tests/*"]

View File

@ -27,3 +27,5 @@ trio==0.31.0
trio-typing==0.10.0
trustme==1.2.1
uvicorn==0.35.0
aiofiles==25.1.0
types-aiofiles==25.1.0.20251011

View File

@ -1,9 +1,14 @@
import io
import typing
import aiofiles
import anyio
import pytest
import trio
import httpx
from httpx._content import AsyncIteratorByteStream
from httpx._types import AsyncReadableFile, is_async_readable_file
method = "POST"
url = "https://www.example.com"
@ -516,3 +521,71 @@ def test_allow_nan_false():
ValueError, match="Out of range float values are not JSON compliant"
):
httpx.Response(200, json=data_with_inf)
@pytest.mark.parametrize(
"client_method,content_seed,mode",
[
("put", "a🥳", "rt"),
("post", "a🥳", "rt"),
("put", "a🥳", "rb"),
("post", "a🥳", "rb"),
],
ids=["put_text", "post_text", "put_binary", "post_binary"],
)
@pytest.mark.anyio
async def test_chunked_async_file_content(
tmp_path, anyio_backend, monkeypatch, client_method, server, content_seed, mode
):
total_chunks = 3
seed_size = len(content_seed.encode()) if "b" in mode else len(content_seed)
read_calls_expected = total_chunks * seed_size + 1
content = "".join(
[content_seed * AsyncIteratorByteStream.CHUNK_SIZE] * total_chunks
)
content_bytes = content.encode()
to_upload = tmp_path / "upload.txt"
to_upload.write_bytes(content_bytes)
url = server.url.copy_with(path="/echo_body")
async def checks(client: httpx.AsyncClient, async_file: AsyncReadableFile) -> None:
read_called = 0
fileno_called = 0
original_read = async_file.read
original_fileno = async_file.fileno
async def mock_read(*args, **kwargs):
nonlocal read_called
read_called += 1
return await original_read(*args, **kwargs)
def mock_fileno(*args):
nonlocal fileno_called
fileno_called += 1
return original_fileno(*args)
monkeypatch.setattr(async_file, "read", mock_read)
monkeypatch.setattr(async_file, "fileno", mock_fileno)
response = await getattr(client, client_method)(url=url, content=async_file)
assert response.status_code == 200
assert response.content == content_bytes
assert response.request.headers["Content-Length"] == str(len(content_bytes))
assert read_called == read_calls_expected
assert fileno_called == 1
async with (
await anyio.open_file(to_upload, mode=mode)
if anyio_backend != "trio"
else await trio.open_file(to_upload, mode=mode) as async_file,
httpx.AsyncClient() as client,
):
assert is_async_readable_file(async_file)
await checks(client, async_file)
if anyio_backend != "trio": # aiofiles doesn't work with trio
async with (
aiofiles.open(to_upload, mode=mode) as aio_file,
httpx.AsyncClient() as client,
):
assert is_async_readable_file(aio_file)
await checks(client, aio_file)

View File

@ -4,9 +4,13 @@ import io
import tempfile
import typing
import anyio
import pytest
import trio
import httpx
from httpx._multipart import FileField
from httpx._types import AsyncReadableFile, is_async_readable_file
def echo_request_content(request: httpx.Request) -> httpx.Response:
@ -467,3 +471,96 @@ class TestHeaderParamHTML5Formatting:
files = {"upload": (filename, b"<file content>")}
request = httpx.Request("GET", "https://www.example.com", files=files)
assert expected in request.read()
@pytest.mark.parametrize(
"content_seed,mode",
[
("a🥳", "rt"),
("a🥳", "rb"),
],
ids=["text_mode", "binary_mode"],
)
@pytest.mark.anyio
async def test_chunked_async_file_multipart(
tmp_path, anyio_backend, monkeypatch, server, content_seed, mode
):
total_chunks = 3
seed_size = len(content_seed.encode()) if "b" in mode else len(content_seed)
read_calls_expected = total_chunks * seed_size + 1
content = "".join([content_seed * FileField.CHUNK_SIZE] * total_chunks)
content_bytes = content.encode()
to_upload = tmp_path / "upload.txt"
to_upload.write_bytes(content_bytes)
url = server.url.copy_with(path="/echo_body")
async def checks(client: httpx.AsyncClient, async_file: AsyncReadableFile) -> None:
read_called = 0
fileno_called = False
original_read = async_file.read
original_fileno = async_file.fileno
async def mock_read(*args, **kwargs):
nonlocal read_called
read_called += 1
return await original_read(*args, **kwargs)
def mock_fileno(*args):
nonlocal fileno_called
fileno_called = True
return original_fileno(*args)
monkeypatch.setattr(async_file, "read", mock_read)
monkeypatch.setattr(async_file, "fileno", mock_fileno)
response = await client.post(url=url, files={"file": async_file})
assert response.status_code == 200
boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
boundary_bytes = boundary.encode("ascii")
pre_content = b"".join(
[
b"--" + boundary_bytes + b"\r\n",
b'Content-Disposition: form-data; name="file"; '
b'filename="upload.txt"\r\n',
b"Content-Type: text/plain\r\n",
b"\r\n",
]
)
post_content = b"".join(
[
b"\r\n",
b"--" + boundary_bytes + b"--\r\n",
]
)
assert response.content == b"".join(
[
pre_content,
content_bytes,
post_content,
]
)
assert response.request.headers["Content-Length"] == str(
len(pre_content) + len(post_content) + len(content_bytes)
)
assert read_called == read_calls_expected
assert fileno_called
async with (
await anyio.open_file(to_upload, mode=mode)
if anyio_backend != "trio"
else await trio.open_file(to_upload, mode=mode) as async_file,
httpx.AsyncClient() as client,
):
assert is_async_readable_file(async_file)
await checks(client, async_file)
async with (
await anyio.open_file(to_upload, mode=mode)
if anyio_backend != "trio"
else await trio.open_file(to_upload, mode=mode) as async_file,
):
with (
httpx.Client() as sync_client,
pytest.raises(TypeError, match="AsyncReadableFile is not supported"),
):
sync_client.post(url, files={"file": async_file})