Merge 79494c514a into b5addb64f0
This commit is contained in:
commit
680bc92474
28
httpx/_compat.py
Normal file
28
httpx/_compat.py
Normal file
@ -0,0 +1,28 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from contextlib import aclosing
|
||||
else:
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar
|
||||
|
||||
class _SupportsAclose(Protocol):
|
||||
def aclose(self) -> Awaitable[object]: ...
|
||||
|
||||
_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)
|
||||
|
||||
@asynccontextmanager
|
||||
async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[Any]:
|
||||
try:
|
||||
yield thing
|
||||
finally:
|
||||
await thing.aclose()
|
||||
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeIs
|
||||
else:
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
|
||||
__all__ = ["aclosing", "TypeIs"]
|
||||
@ -22,8 +22,9 @@ from ._types import (
|
||||
RequestFiles,
|
||||
ResponseContent,
|
||||
SyncByteStream,
|
||||
is_async_readable_file,
|
||||
)
|
||||
from ._utils import peek_filelike_length, primitive_value_to_str
|
||||
from ._utils import peek_filelike_length, primitive_value_to_str, to_bytes
|
||||
|
||||
__all__ = ["ByteStream"]
|
||||
|
||||
@ -83,6 +84,11 @@ class AsyncIteratorByteStream(AsyncByteStream):
|
||||
while chunk:
|
||||
yield chunk
|
||||
chunk = await self._stream.aread(self.CHUNK_SIZE)
|
||||
elif is_async_readable_file(self._stream):
|
||||
chunk = await self._stream.read(self.CHUNK_SIZE)
|
||||
while chunk:
|
||||
yield to_bytes(chunk)
|
||||
chunk = await self._stream.read(self.CHUNK_SIZE)
|
||||
else:
|
||||
# Otherwise iterate.
|
||||
async for part in self._stream:
|
||||
@ -127,7 +133,12 @@ def encode_content(
|
||||
return headers, IteratorByteStream(content) # type: ignore
|
||||
|
||||
elif isinstance(content, AsyncIterable):
|
||||
headers = {"Transfer-Encoding": "chunked"}
|
||||
if is_async_readable_file(content) and (
|
||||
content_length_or_none := peek_filelike_length(content)
|
||||
):
|
||||
headers = {"Content-Length": str(content_length_or_none)}
|
||||
else:
|
||||
headers = {"Transfer-Encoding": "chunked"}
|
||||
return headers, AsyncIteratorByteStream(content)
|
||||
|
||||
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
|
||||
|
||||
@ -7,6 +7,7 @@ import re
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
from ._compat import aclosing
|
||||
from ._types import (
|
||||
AsyncByteStream,
|
||||
FileContent,
|
||||
@ -14,6 +15,7 @@ from ._types import (
|
||||
RequestData,
|
||||
RequestFiles,
|
||||
SyncByteStream,
|
||||
is_async_readable_file,
|
||||
)
|
||||
from ._utils import (
|
||||
peek_filelike_length,
|
||||
@ -201,6 +203,11 @@ class FileField:
|
||||
return self._headers
|
||||
|
||||
def render_data(self) -> typing.Iterator[bytes]:
|
||||
if is_async_readable_file(self.file):
|
||||
raise TypeError(
|
||||
"Invalid type for file. AsyncReadableFile is not supported."
|
||||
)
|
||||
|
||||
if isinstance(self.file, (str, bytes)):
|
||||
yield to_bytes(self.file)
|
||||
return
|
||||
@ -216,10 +223,27 @@ class FileField:
|
||||
yield to_bytes(chunk)
|
||||
chunk = self.file.read(self.CHUNK_SIZE)
|
||||
|
||||
async def arender_data(self) -> typing.AsyncGenerator[bytes]:
|
||||
if not is_async_readable_file(self.file):
|
||||
for chunk in self.render_data():
|
||||
yield chunk
|
||||
return
|
||||
await self.file.seek(0)
|
||||
chunk = await self.file.read(self.CHUNK_SIZE)
|
||||
while chunk:
|
||||
yield to_bytes(chunk)
|
||||
chunk = await self.file.read(self.CHUNK_SIZE)
|
||||
|
||||
def render(self) -> typing.Iterator[bytes]:
|
||||
yield self.render_headers()
|
||||
yield from self.render_data()
|
||||
|
||||
async def arender(self) -> typing.AsyncGenerator[bytes]:
|
||||
yield self.render_headers()
|
||||
async with aclosing(self.arender_data()) as data:
|
||||
async for chunk in data:
|
||||
yield chunk
|
||||
|
||||
|
||||
class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
"""
|
||||
@ -262,6 +286,19 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
yield b"\r\n"
|
||||
yield b"--%s--\r\n" % self.boundary
|
||||
|
||||
async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]:
|
||||
for field in self.fields:
|
||||
yield b"--%s\r\n" % self.boundary
|
||||
if isinstance(field, FileField):
|
||||
async with aclosing(field.arender()) as data:
|
||||
async for chunk in data:
|
||||
yield chunk
|
||||
else:
|
||||
for chunk in field.render():
|
||||
yield chunk
|
||||
yield b"\r\n"
|
||||
yield b"--%s--\r\n" % self.boundary
|
||||
|
||||
def get_content_length(self) -> int | None:
|
||||
"""
|
||||
Return the length of the multipart encoded content, or `None` if
|
||||
@ -296,5 +333,6 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
|
||||
yield chunk
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
for chunk in self.iter_chunks():
|
||||
yield chunk
|
||||
async with aclosing(self.aiter_chunks()) as data:
|
||||
async for chunk in data:
|
||||
yield chunk
|
||||
|
||||
@ -2,11 +2,13 @@
|
||||
Type definitions for type checking purposes.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from http.cookiejar import CookieJar
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AnyStr,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
@ -16,11 +18,14 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ._compat import TypeIs
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ._auth import Auth # noqa: F401
|
||||
from ._config import Proxy, Timeout # noqa: F401
|
||||
@ -71,7 +76,18 @@ ResponseExtensions = Mapping[str, Any]
|
||||
|
||||
RequestData = Mapping[str, Any]
|
||||
|
||||
FileContent = Union[IO[bytes], bytes, str]
|
||||
|
||||
class AsyncReadableFile(Protocol):
|
||||
async def __aiter__(self) -> AsyncIterator[AnyStr]: ...
|
||||
|
||||
async def read(self, size: int = -1) -> AnyStr: ...
|
||||
|
||||
def fileno(self) -> int: ...
|
||||
|
||||
async def seek(self, offset: int, whence: Optional[int] = ...) -> int: ...
|
||||
|
||||
|
||||
FileContent = Union[IO[bytes], bytes, str, AsyncReadableFile]
|
||||
FileTypes = Union[
|
||||
# file (or bytes)
|
||||
FileContent,
|
||||
@ -112,3 +128,16 @@ class AsyncByteStream:
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def is_async_readable_file(fp: Any) -> TypeIs[AsyncReadableFile]:
|
||||
return (
|
||||
isinstance(fp, AsyncIterable)
|
||||
and hasattr(fp, "read")
|
||||
and inspect.iscoroutinefunction(fp.read)
|
||||
and hasattr(fp, "fileno")
|
||||
and callable(fp.fileno)
|
||||
and not inspect.iscoroutinefunction(fp.fileno)
|
||||
and hasattr(fp, "seek")
|
||||
and inspect.iscoroutinefunction(fp.seek)
|
||||
)
|
||||
|
||||
@ -128,5 +128,5 @@ markers = [
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["venv/*"]
|
||||
omit = ["venv/*", "httpx/_compat.py"]
|
||||
include = ["httpx/*", "tests/*"]
|
||||
|
||||
@ -27,3 +27,5 @@ trio==0.31.0
|
||||
trio-typing==0.10.0
|
||||
trustme==1.2.1
|
||||
uvicorn==0.35.0
|
||||
aiofiles==25.1.0
|
||||
types-aiofiles==25.1.0.20251011
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
import io
|
||||
import typing
|
||||
|
||||
import aiofiles
|
||||
import anyio
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
import httpx
|
||||
from httpx._content import AsyncIteratorByteStream
|
||||
from httpx._types import AsyncReadableFile, is_async_readable_file
|
||||
|
||||
method = "POST"
|
||||
url = "https://www.example.com"
|
||||
@ -516,3 +521,71 @@ def test_allow_nan_false():
|
||||
ValueError, match="Out of range float values are not JSON compliant"
|
||||
):
|
||||
httpx.Response(200, json=data_with_inf)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_method,content_seed,mode",
|
||||
[
|
||||
("put", "a🥳", "rt"),
|
||||
("post", "a🥳", "rt"),
|
||||
("put", "a🥳", "rb"),
|
||||
("post", "a🥳", "rb"),
|
||||
],
|
||||
ids=["put_text", "post_text", "put_binary", "post_binary"],
|
||||
)
|
||||
@pytest.mark.anyio
|
||||
async def test_chunked_async_file_content(
|
||||
tmp_path, anyio_backend, monkeypatch, client_method, server, content_seed, mode
|
||||
):
|
||||
total_chunks = 3
|
||||
seed_size = len(content_seed.encode()) if "b" in mode else len(content_seed)
|
||||
read_calls_expected = total_chunks * seed_size + 1
|
||||
content = "".join(
|
||||
[content_seed * AsyncIteratorByteStream.CHUNK_SIZE] * total_chunks
|
||||
)
|
||||
content_bytes = content.encode()
|
||||
to_upload = tmp_path / "upload.txt"
|
||||
to_upload.write_bytes(content_bytes)
|
||||
url = server.url.copy_with(path="/echo_body")
|
||||
|
||||
async def checks(client: httpx.AsyncClient, async_file: AsyncReadableFile) -> None:
|
||||
read_called = 0
|
||||
fileno_called = 0
|
||||
original_read = async_file.read
|
||||
original_fileno = async_file.fileno
|
||||
|
||||
async def mock_read(*args, **kwargs):
|
||||
nonlocal read_called
|
||||
read_called += 1
|
||||
return await original_read(*args, **kwargs)
|
||||
|
||||
def mock_fileno(*args):
|
||||
nonlocal fileno_called
|
||||
fileno_called += 1
|
||||
return original_fileno(*args)
|
||||
|
||||
monkeypatch.setattr(async_file, "read", mock_read)
|
||||
monkeypatch.setattr(async_file, "fileno", mock_fileno)
|
||||
response = await getattr(client, client_method)(url=url, content=async_file)
|
||||
assert response.status_code == 200
|
||||
assert response.content == content_bytes
|
||||
assert response.request.headers["Content-Length"] == str(len(content_bytes))
|
||||
assert read_called == read_calls_expected
|
||||
assert fileno_called == 1
|
||||
|
||||
async with (
|
||||
await anyio.open_file(to_upload, mode=mode)
|
||||
if anyio_backend != "trio"
|
||||
else await trio.open_file(to_upload, mode=mode) as async_file,
|
||||
httpx.AsyncClient() as client,
|
||||
):
|
||||
assert is_async_readable_file(async_file)
|
||||
await checks(client, async_file)
|
||||
|
||||
if anyio_backend != "trio": # aiofiles doesn't work with trio
|
||||
async with (
|
||||
aiofiles.open(to_upload, mode=mode) as aio_file,
|
||||
httpx.AsyncClient() as client,
|
||||
):
|
||||
assert is_async_readable_file(aio_file)
|
||||
await checks(client, aio_file)
|
||||
|
||||
@ -4,9 +4,13 @@ import io
|
||||
import tempfile
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
import httpx
|
||||
from httpx._multipart import FileField
|
||||
from httpx._types import AsyncReadableFile, is_async_readable_file
|
||||
|
||||
|
||||
def echo_request_content(request: httpx.Request) -> httpx.Response:
|
||||
@ -467,3 +471,96 @@ class TestHeaderParamHTML5Formatting:
|
||||
files = {"upload": (filename, b"<file content>")}
|
||||
request = httpx.Request("GET", "https://www.example.com", files=files)
|
||||
assert expected in request.read()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content_seed,mode",
|
||||
[
|
||||
("a🥳", "rt"),
|
||||
("a🥳", "rb"),
|
||||
],
|
||||
ids=["text_mode", "binary_mode"],
|
||||
)
|
||||
@pytest.mark.anyio
|
||||
async def test_chunked_async_file_multipart(
|
||||
tmp_path, anyio_backend, monkeypatch, server, content_seed, mode
|
||||
):
|
||||
total_chunks = 3
|
||||
seed_size = len(content_seed.encode()) if "b" in mode else len(content_seed)
|
||||
read_calls_expected = total_chunks * seed_size + 1
|
||||
content = "".join([content_seed * FileField.CHUNK_SIZE] * total_chunks)
|
||||
content_bytes = content.encode()
|
||||
to_upload = tmp_path / "upload.txt"
|
||||
to_upload.write_bytes(content_bytes)
|
||||
url = server.url.copy_with(path="/echo_body")
|
||||
|
||||
async def checks(client: httpx.AsyncClient, async_file: AsyncReadableFile) -> None:
|
||||
read_called = 0
|
||||
fileno_called = False
|
||||
original_read = async_file.read
|
||||
original_fileno = async_file.fileno
|
||||
|
||||
async def mock_read(*args, **kwargs):
|
||||
nonlocal read_called
|
||||
read_called += 1
|
||||
return await original_read(*args, **kwargs)
|
||||
|
||||
def mock_fileno(*args):
|
||||
nonlocal fileno_called
|
||||
fileno_called = True
|
||||
return original_fileno(*args)
|
||||
|
||||
monkeypatch.setattr(async_file, "read", mock_read)
|
||||
monkeypatch.setattr(async_file, "fileno", mock_fileno)
|
||||
response = await client.post(url=url, files={"file": async_file})
|
||||
assert response.status_code == 200
|
||||
boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
|
||||
boundary_bytes = boundary.encode("ascii")
|
||||
pre_content = b"".join(
|
||||
[
|
||||
b"--" + boundary_bytes + b"\r\n",
|
||||
b'Content-Disposition: form-data; name="file"; '
|
||||
b'filename="upload.txt"\r\n',
|
||||
b"Content-Type: text/plain\r\n",
|
||||
b"\r\n",
|
||||
]
|
||||
)
|
||||
post_content = b"".join(
|
||||
[
|
||||
b"\r\n",
|
||||
b"--" + boundary_bytes + b"--\r\n",
|
||||
]
|
||||
)
|
||||
assert response.content == b"".join(
|
||||
[
|
||||
pre_content,
|
||||
content_bytes,
|
||||
post_content,
|
||||
]
|
||||
)
|
||||
assert response.request.headers["Content-Length"] == str(
|
||||
len(pre_content) + len(post_content) + len(content_bytes)
|
||||
)
|
||||
assert read_called == read_calls_expected
|
||||
assert fileno_called
|
||||
|
||||
async with (
|
||||
await anyio.open_file(to_upload, mode=mode)
|
||||
if anyio_backend != "trio"
|
||||
else await trio.open_file(to_upload, mode=mode) as async_file,
|
||||
httpx.AsyncClient() as client,
|
||||
):
|
||||
assert is_async_readable_file(async_file)
|
||||
|
||||
await checks(client, async_file)
|
||||
|
||||
async with (
|
||||
await anyio.open_file(to_upload, mode=mode)
|
||||
if anyio_backend != "trio"
|
||||
else await trio.open_file(to_upload, mode=mode) as async_file,
|
||||
):
|
||||
with (
|
||||
httpx.Client() as sync_client,
|
||||
pytest.raises(TypeError, match="AsyncReadableFile is not supported"),
|
||||
):
|
||||
sync_client.post(url, files={"file": async_file})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user