parent
e1f7791e97
commit
fbb21fb1ae
@ -74,7 +74,6 @@ except httpx.HTTPStatusError as exc:
|
||||
* UnsupportedProtocol
|
||||
* DecodingError
|
||||
* TooManyRedirects
|
||||
* RequestBodyUnavailable
|
||||
* HTTPStatusError
|
||||
* InvalidURL
|
||||
* NotRedirectResponse
|
||||
@ -149,9 +148,6 @@ except httpx.HTTPStatusError as exc:
|
||||
::: httpx.TooManyRedirects
|
||||
:docstring:
|
||||
|
||||
::: httpx.RequestBodyUnavailable
|
||||
:docstring:
|
||||
|
||||
::: httpx.HTTPStatusError
|
||||
:docstring:
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ from ._exceptions import (
|
||||
ReadError,
|
||||
ReadTimeout,
|
||||
RemoteProtocolError,
|
||||
RequestBodyUnavailable,
|
||||
RequestError,
|
||||
RequestNotRead,
|
||||
ResponseClosed,
|
||||
@ -84,7 +83,6 @@ __all__ = [
|
||||
"RemoteProtocolError",
|
||||
"request",
|
||||
"Request",
|
||||
"RequestBodyUnavailable",
|
||||
"RequestError",
|
||||
"RequestNotRead",
|
||||
"Response",
|
||||
|
||||
@ -6,7 +6,7 @@ import typing
|
||||
from base64 import b64encode
|
||||
from urllib.request import parse_http_list
|
||||
|
||||
from ._exceptions import ProtocolError, RequestBodyUnavailable
|
||||
from ._exceptions import ProtocolError
|
||||
from ._models import Request, Response
|
||||
from ._utils import to_bytes, to_str, unquote
|
||||
|
||||
@ -157,13 +157,6 @@ class DigestAuth(Auth):
|
||||
self._password = to_bytes(password)
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
if not request.stream.can_replay():
|
||||
raise RequestBodyUnavailable(
|
||||
"Cannot use digest auth with streaming requests that are unable "
|
||||
"to replay the request body if a second request is required.",
|
||||
request=request,
|
||||
)
|
||||
|
||||
response = yield request
|
||||
|
||||
if response.status_code != 401 or "www-authenticate" not in response.headers:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import functools
|
||||
import typing
|
||||
import warnings
|
||||
@ -18,13 +19,11 @@ from ._config import (
|
||||
UnsetType,
|
||||
create_ssl_context,
|
||||
)
|
||||
from ._content_streams import ContentStream
|
||||
from ._decoders import SUPPORTED_DECODERS
|
||||
from ._exceptions import (
|
||||
HTTPCORE_EXC_MAP,
|
||||
InvalidURL,
|
||||
RemoteProtocolError,
|
||||
RequestBodyUnavailable,
|
||||
TooManyRedirects,
|
||||
map_exceptions,
|
||||
)
|
||||
@ -34,6 +33,7 @@ from ._transports.asgi import ASGITransport
|
||||
from ._transports.wsgi import WSGITransport
|
||||
from ._types import (
|
||||
AuthTypes,
|
||||
ByteStream,
|
||||
CertTypes,
|
||||
CookieTypes,
|
||||
HeaderTypes,
|
||||
@ -480,20 +480,13 @@ class BaseClient:
|
||||
|
||||
def _redirect_stream(
|
||||
self, request: Request, method: str
|
||||
) -> typing.Optional[ContentStream]:
|
||||
) -> typing.Optional[ByteStream]:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return None
|
||||
|
||||
if not request.stream.can_replay():
|
||||
raise RequestBodyUnavailable(
|
||||
"Got a redirect response, but the request body was streaming "
|
||||
"and is no longer available.",
|
||||
request=request,
|
||||
)
|
||||
|
||||
return request.stream
|
||||
|
||||
|
||||
@ -864,16 +857,22 @@ class Client(BaseClient):
|
||||
request.method.encode(),
|
||||
request.url.raw,
|
||||
headers=request.headers.raw,
|
||||
stream=request.stream,
|
||||
stream=request.stream, # type: ignore
|
||||
timeout=timeout.as_dict(),
|
||||
)
|
||||
|
||||
def on_close(response: Response) -> None:
|
||||
response.elapsed = datetime.timedelta(timer.sync_elapsed())
|
||||
if hasattr(stream, "close"):
|
||||
stream.close()
|
||||
|
||||
response = Response(
|
||||
status_code,
|
||||
http_version=http_version.decode("ascii"),
|
||||
headers=headers,
|
||||
stream=stream, # type: ignore
|
||||
request=request,
|
||||
elapsed_func=timer.sync_elapsed,
|
||||
on_close=on_close,
|
||||
)
|
||||
|
||||
self.cookies.extract_cookies(response)
|
||||
@ -1509,16 +1508,22 @@ class AsyncClient(BaseClient):
|
||||
request.method.encode(),
|
||||
request.url.raw,
|
||||
headers=request.headers.raw,
|
||||
stream=request.stream,
|
||||
stream=request.stream, # type: ignore
|
||||
timeout=timeout.as_dict(),
|
||||
)
|
||||
|
||||
async def on_close(response: Response) -> None:
|
||||
response.elapsed = datetime.timedelta(await timer.async_elapsed())
|
||||
if hasattr(stream, "close"):
|
||||
await stream.aclose()
|
||||
|
||||
response = Response(
|
||||
status_code,
|
||||
http_version=http_version.decode("ascii"),
|
||||
headers=headers,
|
||||
stream=stream, # type: ignore
|
||||
request=request,
|
||||
elapsed_func=timer.async_elapsed,
|
||||
on_close=on_close,
|
||||
)
|
||||
|
||||
self.cookies.extract_cookies(response)
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
import binascii
|
||||
import inspect
|
||||
import os
|
||||
import typing
|
||||
from json import dumps as json_dumps
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpcore
|
||||
|
||||
from ._exceptions import StreamConsumed
|
||||
from ._types import (
|
||||
ByteStream,
|
||||
FileContent,
|
||||
FileTypes,
|
||||
RequestContent,
|
||||
@ -24,36 +24,7 @@ from ._utils import (
|
||||
)
|
||||
|
||||
|
||||
class ContentStream(httpcore.AsyncByteStream, httpcore.SyncByteStream):
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
"""
|
||||
Return a dictionary of headers that are implied by the encoding.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
"""
|
||||
Return `True` if `__aiter__` can be called multiple times.
|
||||
|
||||
We need this in cases such determining if we can re-issue a request
|
||||
body when we receive a redirect response.
|
||||
"""
|
||||
return True
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
yield b""
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
yield b""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ByteStream(ContentStream):
|
||||
class PlainByteStream:
|
||||
"""
|
||||
Request content encoded as plain bytes.
|
||||
"""
|
||||
@ -74,59 +45,41 @@ class ByteStream(ContentStream):
|
||||
yield self.body
|
||||
|
||||
|
||||
class IteratorStream(ContentStream):
|
||||
class GeneratorStream:
|
||||
"""
|
||||
Request content encoded as plain bytes, using an byte iterator.
|
||||
Request content encoded as plain bytes, using an byte generator.
|
||||
"""
|
||||
|
||||
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
|
||||
self.iterator = iterator
|
||||
self.is_stream_consumed = False
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
return {"Transfer-Encoding": "chunked"}
|
||||
def __init__(self, generator: typing.Iterable[bytes]) -> None:
|
||||
self._generator = generator
|
||||
self._is_stream_consumed = False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
if self.is_stream_consumed:
|
||||
if self._is_stream_consumed:
|
||||
raise StreamConsumed()
|
||||
self.is_stream_consumed = True
|
||||
for part in self.iterator:
|
||||
self._is_stream_consumed = True
|
||||
for part in self._generator:
|
||||
yield part
|
||||
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
raise RuntimeError("Attempted to call a async iterator on an sync stream.")
|
||||
|
||||
|
||||
class AsyncIteratorStream(ContentStream):
|
||||
class AsyncGeneratorStream:
|
||||
"""
|
||||
Request content encoded as plain bytes, using an async byte iterator.
|
||||
"""
|
||||
|
||||
def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
|
||||
self.aiterator = aiterator
|
||||
self.is_stream_consumed = False
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
return {"Transfer-Encoding": "chunked"}
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
raise RuntimeError("Attempted to call a sync iterator on an async stream.")
|
||||
def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
|
||||
self._agenerator = agenerator
|
||||
self._is_stream_consumed = False
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
if self.is_stream_consumed:
|
||||
if self._is_stream_consumed:
|
||||
raise StreamConsumed()
|
||||
self.is_stream_consumed = True
|
||||
async for part in self.aiterator:
|
||||
self._is_stream_consumed = True
|
||||
async for part in self._agenerator:
|
||||
yield part
|
||||
|
||||
|
||||
class JSONStream(ContentStream):
|
||||
class JSONStream:
|
||||
"""
|
||||
Request content encoded as JSON.
|
||||
"""
|
||||
@ -146,7 +99,7 @@ class JSONStream(ContentStream):
|
||||
yield self.body
|
||||
|
||||
|
||||
class URLEncodedStream(ContentStream):
|
||||
class URLEncodedStream:
|
||||
"""
|
||||
Request content as URL encoded form data.
|
||||
"""
|
||||
@ -166,7 +119,7 @@ class URLEncodedStream(ContentStream):
|
||||
yield self.body
|
||||
|
||||
|
||||
class MultipartStream(ContentStream):
|
||||
class MultipartStream:
|
||||
"""
|
||||
Request content as streaming multipart encoded form data.
|
||||
"""
|
||||
@ -208,9 +161,6 @@ class MultipartStream(ContentStream):
|
||||
data = self.render_data()
|
||||
return len(headers) + len(data)
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return True
|
||||
|
||||
def render(self) -> typing.Iterator[bytes]:
|
||||
yield self.render_headers()
|
||||
yield self.render_data()
|
||||
@ -239,6 +189,7 @@ class MultipartStream(ContentStream):
|
||||
self.filename = filename
|
||||
self.file = fileobj
|
||||
self.content_type = content_type
|
||||
self._consumed = False
|
||||
|
||||
def get_length(self) -> int:
|
||||
headers = self.render_headers()
|
||||
@ -284,17 +235,13 @@ class MultipartStream(ContentStream):
|
||||
yield self._data
|
||||
return
|
||||
|
||||
if self._consumed: # pragma: nocover
|
||||
self.file.seek(0)
|
||||
self._consumed = True
|
||||
|
||||
for chunk in self.file:
|
||||
yield to_bytes(chunk)
|
||||
|
||||
# Get ready for the next replay, if possible.
|
||||
if self.can_replay():
|
||||
assert self.file.seekable()
|
||||
self.file.seek(0)
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return True if isinstance(self.file, (str, bytes)) else self.file.seekable()
|
||||
|
||||
def render(self) -> typing.Iterator[bytes]:
|
||||
yield self.render_headers()
|
||||
yield from self.render_data()
|
||||
@ -346,9 +293,6 @@ class MultipartStream(ContentStream):
|
||||
|
||||
# Content stream interface.
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return all(field.can_replay() for field in self.fields)
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
content_length = str(self.get_content_length())
|
||||
content_type = self.content_type
|
||||
@ -369,10 +313,10 @@ def encode_request(
|
||||
files: RequestFiles = None,
|
||||
json: typing.Any = None,
|
||||
boundary: bytes = None,
|
||||
) -> ContentStream:
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
"""
|
||||
Handles encoding the given `content`, `data`, `files`, and `json`,
|
||||
returning a `ContentStream` implementation.
|
||||
returning a two-tuple of (<headers>, <stream>).
|
||||
"""
|
||||
if data is not None and not isinstance(data, dict):
|
||||
# We prefer to seperate `content=<bytes|byte iterator|bytes aiterator>`
|
||||
@ -387,39 +331,65 @@ def encode_request(
|
||||
|
||||
if content is not None:
|
||||
if isinstance(content, (str, bytes)):
|
||||
return ByteStream(body=content)
|
||||
elif hasattr(content, "__aiter__"):
|
||||
content = typing.cast(typing.AsyncIterator[bytes], content)
|
||||
return AsyncIteratorStream(aiterator=content)
|
||||
elif hasattr(content, "__iter__"):
|
||||
content = typing.cast(typing.Iterator[bytes], content)
|
||||
return IteratorStream(iterator=content)
|
||||
byte_stream = PlainByteStream(body=content)
|
||||
headers = byte_stream.get_headers()
|
||||
return headers, byte_stream
|
||||
elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
|
||||
if inspect.isgenerator(content):
|
||||
generator_stream = GeneratorStream(content) # type: ignore
|
||||
return {"Transfer-Encoding": "chunked"}, generator_stream
|
||||
if inspect.isasyncgen(content):
|
||||
agenerator_stream = AsyncGeneratorStream(content) # type: ignore
|
||||
return {"Transfer-Encoding": "chunked"}, agenerator_stream
|
||||
return {"Transfer-Encoding": "chunked"}, content # type: ignore
|
||||
else:
|
||||
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
|
||||
|
||||
elif data:
|
||||
if files:
|
||||
return MultipartStream(data=data, files=files, boundary=boundary)
|
||||
multipart_stream = MultipartStream(
|
||||
data=data, files=files, boundary=boundary
|
||||
)
|
||||
headers = multipart_stream.get_headers()
|
||||
return headers, multipart_stream
|
||||
else:
|
||||
return URLEncodedStream(data=data)
|
||||
urlencoded_stream = URLEncodedStream(data=data)
|
||||
headers = urlencoded_stream.get_headers()
|
||||
return headers, urlencoded_stream
|
||||
|
||||
elif files:
|
||||
return MultipartStream(data={}, files=files, boundary=boundary)
|
||||
multipart_stream = MultipartStream(data={}, files=files, boundary=boundary)
|
||||
headers = multipart_stream.get_headers()
|
||||
return headers, multipart_stream
|
||||
|
||||
elif json is not None:
|
||||
return JSONStream(json=json)
|
||||
json_stream = JSONStream(json=json)
|
||||
headers = json_stream.get_headers()
|
||||
return headers, json_stream
|
||||
|
||||
return ByteStream(body=b"")
|
||||
byte_stream = PlainByteStream(body=b"")
|
||||
headers = byte_stream.get_headers()
|
||||
return headers, byte_stream
|
||||
|
||||
|
||||
def encode_response(content: ResponseContent = None) -> ContentStream:
|
||||
def encode_response(
|
||||
content: ResponseContent = None,
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
if content is None:
|
||||
return ByteStream(b"")
|
||||
byte_stream = PlainByteStream(b"")
|
||||
headers = byte_stream.get_headers()
|
||||
return headers, byte_stream
|
||||
elif isinstance(content, bytes):
|
||||
return ByteStream(body=content)
|
||||
elif isinstance(content, typing.AsyncIterator):
|
||||
return AsyncIteratorStream(aiterator=content)
|
||||
elif isinstance(content, typing.Iterator):
|
||||
return IteratorStream(iterator=content)
|
||||
byte_stream = PlainByteStream(body=content)
|
||||
headers = byte_stream.get_headers()
|
||||
return headers, byte_stream
|
||||
elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
|
||||
if inspect.isgenerator(content):
|
||||
generator_stream = GeneratorStream(content) # type: ignore
|
||||
return {"Transfer-Encoding": "chunked"}, generator_stream
|
||||
elif inspect.isasyncgen(content):
|
||||
agenerator_stream = AsyncGeneratorStream(content) # type: ignore
|
||||
return {"Transfer-Encoding": "chunked"}, agenerator_stream
|
||||
return {"Transfer-Encoding": "chunked"}, content # type: ignore
|
||||
|
||||
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
|
||||
|
||||
@ -207,13 +207,6 @@ class TooManyRedirects(RequestError):
|
||||
"""
|
||||
|
||||
|
||||
class RequestBodyUnavailable(RequestError):
|
||||
"""
|
||||
Had to send the request again, but the request body was streaming, and is
|
||||
no longer available.
|
||||
"""
|
||||
|
||||
|
||||
# Client errors
|
||||
|
||||
|
||||
@ -283,14 +276,18 @@ class StreamError(Exception):
|
||||
|
||||
class StreamConsumed(StreamError):
|
||||
"""
|
||||
Attempted to read or stream response content, but the content has already
|
||||
Attempted to read or stream content, but the content has already
|
||||
been streamed.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
message = (
|
||||
"Attempted to read or stream response content, but the content has "
|
||||
"already been streamed."
|
||||
"Attempted to read or stream some content, but the content has "
|
||||
"already been streamed. For requests, this could be due to passing "
|
||||
"a generator as request content, and then receiving a redirect "
|
||||
"response or a secondary request as part of an authentication flow."
|
||||
"For responses, this could be due to attempting to stream the response "
|
||||
"content more than once."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode
|
||||
import rfc3986
|
||||
import rfc3986.exceptions
|
||||
|
||||
from ._content_streams import ByteStream, ContentStream, encode_request, encode_response
|
||||
from ._content_streams import PlainByteStream, encode_request, encode_response
|
||||
from ._decoders import (
|
||||
SUPPORTED_DECODERS,
|
||||
ContentDecoder,
|
||||
@ -37,6 +37,7 @@ from ._exceptions import (
|
||||
)
|
||||
from ._status_codes import codes
|
||||
from ._types import (
|
||||
ByteStream,
|
||||
CookieTypes,
|
||||
HeaderTypes,
|
||||
PrimitiveData,
|
||||
@ -606,7 +607,7 @@ class Request:
|
||||
data: RequestData = None,
|
||||
files: RequestFiles = None,
|
||||
json: typing.Any = None,
|
||||
stream: ContentStream = None,
|
||||
stream: ByteStream = None,
|
||||
):
|
||||
if isinstance(method, bytes):
|
||||
self.method = method.decode("ascii").upper()
|
||||
@ -618,14 +619,28 @@ class Request:
|
||||
Cookies(cookies).set_cookie_header(self)
|
||||
|
||||
if stream is not None:
|
||||
# There's an important distinction between `Request(content=...)`,
|
||||
# and `Request(stream=...)`.
|
||||
#
|
||||
# Using `content=...` implies automatically populated content headers,
|
||||
# of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
|
||||
#
|
||||
# Using `stream=...` will not automatically include any content headers.
|
||||
#
|
||||
# As an end-user you don't really need `stream=...`. It's only
|
||||
# useful when:
|
||||
#
|
||||
# * Preserving the request stream when copying requests, eg for redirects.
|
||||
# * Creating request instances on the *server-side* of the transport API.
|
||||
self.stream = stream
|
||||
self._prepare({})
|
||||
else:
|
||||
self.stream = encode_request(content, data, files, json)
|
||||
headers, stream = encode_request(content, data, files, json)
|
||||
self._prepare(headers)
|
||||
self.stream = stream
|
||||
|
||||
self._prepare()
|
||||
|
||||
def _prepare(self) -> None:
|
||||
for key, value in self.stream.get_headers().items():
|
||||
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
|
||||
for key, value in default_headers.items():
|
||||
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
|
||||
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
|
||||
continue
|
||||
@ -657,11 +672,12 @@ class Request:
|
||||
Read and return the request content.
|
||||
"""
|
||||
if not hasattr(self, "_content"):
|
||||
assert isinstance(self.stream, typing.Iterable)
|
||||
self._content = b"".join(self.stream)
|
||||
# If a streaming request has been read entirely into memory, then
|
||||
# we can replace the stream with a raw bytes implementation,
|
||||
# to ensure that any non-replayable streams can still be used.
|
||||
self.stream = ByteStream(self._content)
|
||||
self.stream = PlainByteStream(self._content)
|
||||
return self._content
|
||||
|
||||
async def aread(self) -> bytes:
|
||||
@ -669,11 +685,12 @@ class Request:
|
||||
Read and return the request content.
|
||||
"""
|
||||
if not hasattr(self, "_content"):
|
||||
assert isinstance(self.stream, typing.AsyncIterable)
|
||||
self._content = b"".join([part async for part in self.stream])
|
||||
# If a streaming request has been read entirely into memory, then
|
||||
# we can replace the stream with a raw bytes implementation,
|
||||
# to ensure that any non-replayable streams can still be used.
|
||||
self.stream = ByteStream(self._content)
|
||||
self.stream = PlainByteStream(self._content)
|
||||
return self._content
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -690,10 +707,10 @@ class Response:
|
||||
request: Request = None,
|
||||
http_version: str = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: ContentStream = None,
|
||||
content: ResponseContent = None,
|
||||
stream: ByteStream = None,
|
||||
history: typing.List["Response"] = None,
|
||||
elapsed_func: typing.Callable = None,
|
||||
on_close: typing.Callable = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.http_version = http_version
|
||||
@ -704,20 +721,41 @@ class Response:
|
||||
self.call_next: typing.Optional[typing.Callable] = None
|
||||
|
||||
self.history = [] if history is None else list(history)
|
||||
self._elapsed_func = elapsed_func
|
||||
self._on_close = on_close
|
||||
|
||||
self.is_closed = False
|
||||
self.is_stream_consumed = False
|
||||
|
||||
if stream is not None:
|
||||
self._raw_stream = stream
|
||||
# There's an important distinction between `Response(content=...)`,
|
||||
# and `Response(stream=...)`.
|
||||
#
|
||||
# Using `content=...` implies automatically populated content headers,
|
||||
# of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
|
||||
#
|
||||
# Using `stream=...` will not automatically include any content headers.
|
||||
#
|
||||
# As an end-user you don't really need `stream=...`. It's only
|
||||
# useful when creating response instances having received a stream
|
||||
# from the transport API.
|
||||
self.stream = stream
|
||||
else:
|
||||
self._raw_stream = encode_response(content)
|
||||
headers, stream = encode_response(content)
|
||||
self._prepare(headers)
|
||||
self.stream = stream
|
||||
if content is None or isinstance(content, bytes):
|
||||
# Load the response body, except for streaming content.
|
||||
self.read()
|
||||
|
||||
self._num_bytes_downloaded = 0
|
||||
|
||||
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
|
||||
for key, value in default_headers.items():
|
||||
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
|
||||
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
|
||||
continue
|
||||
self.headers.setdefault(key, value)
|
||||
|
||||
@property
|
||||
def elapsed(self) -> datetime.timedelta:
|
||||
"""
|
||||
@ -729,7 +767,11 @@ class Response:
|
||||
"'.elapsed' may only be accessed after the response "
|
||||
"has been read or closed."
|
||||
)
|
||||
return datetime.timedelta(seconds=self._elapsed)
|
||||
return self._elapsed
|
||||
|
||||
@elapsed.setter
|
||||
def elapsed(self, elapsed: datetime.timedelta) -> None:
|
||||
self._elapsed = elapsed
|
||||
|
||||
@property
|
||||
def request(self) -> Request:
|
||||
@ -963,11 +1005,13 @@ class Response:
|
||||
raise StreamConsumed()
|
||||
if self.is_closed:
|
||||
raise ResponseClosed()
|
||||
if not isinstance(self.stream, typing.Iterable):
|
||||
raise RuntimeError("Attempted to call a sync iterator on an async stream.")
|
||||
|
||||
self.is_stream_consumed = True
|
||||
self._num_bytes_downloaded = 0
|
||||
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
|
||||
for part in self._raw_stream:
|
||||
for part in self.stream:
|
||||
self._num_bytes_downloaded += len(part)
|
||||
yield part
|
||||
self.close()
|
||||
@ -992,9 +1036,8 @@ class Response:
|
||||
"""
|
||||
if not self.is_closed:
|
||||
self.is_closed = True
|
||||
if self._elapsed_func is not None:
|
||||
self._elapsed = self._elapsed_func()
|
||||
self._raw_stream.close()
|
||||
if self._on_close is not None:
|
||||
self._on_close(self)
|
||||
|
||||
async def aread(self) -> bytes:
|
||||
"""
|
||||
@ -1047,11 +1090,13 @@ class Response:
|
||||
raise StreamConsumed()
|
||||
if self.is_closed:
|
||||
raise ResponseClosed()
|
||||
if not isinstance(self.stream, typing.AsyncIterable):
|
||||
raise RuntimeError("Attempted to call a async iterator on a sync stream.")
|
||||
|
||||
self.is_stream_consumed = True
|
||||
self._num_bytes_downloaded = 0
|
||||
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
|
||||
async for part in self._raw_stream:
|
||||
async for part in self.stream:
|
||||
self._num_bytes_downloaded += len(part)
|
||||
yield part
|
||||
await self.aclose()
|
||||
@ -1075,9 +1120,8 @@ class Response:
|
||||
"""
|
||||
if not self.is_closed:
|
||||
self.is_closed = True
|
||||
if self._elapsed_func is not None:
|
||||
self._elapsed = await self._elapsed_func()
|
||||
await self._raw_stream.aclose()
|
||||
if self._on_close is not None:
|
||||
await self._on_close(self)
|
||||
|
||||
|
||||
class Cookies(MutableMapping):
|
||||
|
||||
@ -7,10 +7,10 @@ from http.cookiejar import CookieJar
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
AsyncIterator,
|
||||
AsyncIterable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
@ -66,8 +66,9 @@ AuthTypes = Union[
|
||||
None,
|
||||
]
|
||||
|
||||
RequestContent = Union[str, bytes, Iterator[bytes], AsyncIterator[bytes]]
|
||||
ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]
|
||||
RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
|
||||
ResponseContent = Union[bytes, Iterable[bytes], AsyncIterable[bytes]]
|
||||
ByteStream = Union[Iterable[bytes], AsyncIterable[bytes]]
|
||||
|
||||
RequestData = dict
|
||||
|
||||
|
||||
@ -13,16 +13,7 @@ import typing
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx import (
|
||||
URL,
|
||||
Auth,
|
||||
BasicAuth,
|
||||
DigestAuth,
|
||||
ProtocolError,
|
||||
Request,
|
||||
RequestBodyUnavailable,
|
||||
Response,
|
||||
)
|
||||
from httpx import URL, Auth, BasicAuth, DigestAuth, ProtocolError, Request, Response
|
||||
from tests.utils import AsyncMockTransport, MockTransport
|
||||
|
||||
from ..common import FIXTURES_DIR
|
||||
@ -617,13 +608,13 @@ def test_sync_auth_history() -> None:
|
||||
async def test_digest_auth_unavailable_streaming_body():
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
app = App()
|
||||
app = DigestApp()
|
||||
|
||||
async def streaming_body():
|
||||
yield b"Example request body" # pragma: nocover
|
||||
|
||||
async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
|
||||
with pytest.raises(RequestBodyUnavailable):
|
||||
with pytest.raises(httpx.StreamConsumed):
|
||||
await client.post(url, data=streaming_body(), auth=auth)
|
||||
|
||||
|
||||
|
||||
@ -334,7 +334,7 @@ def test_cannot_redirect_streaming_body():
|
||||
def streaming_body():
|
||||
yield b"Example request body" # pragma: nocover
|
||||
|
||||
with pytest.raises(httpx.RequestBodyUnavailable):
|
||||
with pytest.raises(httpx.StreamConsumed):
|
||||
client.post(url, content=streaming_body())
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
@ -18,6 +20,40 @@ def test_content_length_header():
|
||||
assert request.headers["Content-Length"] == "8"
|
||||
|
||||
|
||||
def test_iterable_content():
|
||||
class Content:
|
||||
def __iter__(self):
|
||||
yield b"test 123" # pragma: nocover
|
||||
|
||||
request = httpx.Request("POST", "http://example.org", content=Content())
|
||||
assert request.headers == httpx.Headers(
|
||||
{"Host": "example.org", "Transfer-Encoding": "chunked"}
|
||||
)
|
||||
|
||||
|
||||
def test_generator_with_transfer_encoding_header():
|
||||
def content():
|
||||
yield b"test 123" # pragma: nocover
|
||||
|
||||
request = httpx.Request("POST", "http://example.org", content=content())
|
||||
assert request.headers == httpx.Headers(
|
||||
{"Host": "example.org", "Transfer-Encoding": "chunked"}
|
||||
)
|
||||
|
||||
|
||||
def test_generator_with_content_length_header():
|
||||
def content():
|
||||
yield b"test 123" # pragma: nocover
|
||||
|
||||
headers = {"Content-Length": "8"}
|
||||
request = httpx.Request(
|
||||
"POST", "http://example.org", content=content(), headers=headers
|
||||
)
|
||||
assert request.headers == httpx.Headers(
|
||||
{"Host": "example.org", "Content-Length": "8"}
|
||||
)
|
||||
|
||||
|
||||
def test_url_encoded_data():
|
||||
request = httpx.Request("POST", "http://example.org", data={"test": "123"})
|
||||
request.read()
|
||||
@ -51,6 +87,8 @@ def test_read_and_stream_data():
|
||||
# Needed for cases such as authentication classes that read the request body.
|
||||
request = httpx.Request("POST", "http://example.org", json={"test": 123})
|
||||
request.read()
|
||||
assert request.stream is not None
|
||||
assert isinstance(request.stream, typing.Iterable)
|
||||
content = b"".join([part for part in request.stream])
|
||||
assert content == request.content
|
||||
|
||||
@ -61,6 +99,8 @@ async def test_aread_and_stream_data():
|
||||
# Needed for cases such as authentication classes that read the request body.
|
||||
request = httpx.Request("POST", "http://example.org", json={"test": 123})
|
||||
await request.aread()
|
||||
assert request.stream is not None
|
||||
assert isinstance(request.stream, typing.AsyncIterable)
|
||||
content = b"".join([part async for part in request.stream])
|
||||
assert content == request.content
|
||||
|
||||
@ -68,7 +108,7 @@ async def test_aread_and_stream_data():
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_content_without_read():
|
||||
# Ensure a request may still be streamed if it has been read.
|
||||
# Needed for cases such as authentication classes that read the request body.
|
||||
# Needed for cases such as authentication classes that read the request body.
|
||||
request = httpx.Request("POST", "http://example.org", json={"test": 123})
|
||||
with pytest.raises(httpx.RequestNotRead):
|
||||
request.content
|
||||
|
||||
@ -7,6 +7,12 @@ import pytest
|
||||
import httpx
|
||||
|
||||
|
||||
class StreamingBody:
|
||||
def __iter__(self):
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
|
||||
def streaming_body():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
@ -230,6 +236,21 @@ def test_read():
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
def test_empty_read():
|
||||
response = httpx.Response(200)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == ""
|
||||
assert response.encoding is None
|
||||
assert response.is_closed
|
||||
|
||||
content = response.read()
|
||||
|
||||
assert content == b""
|
||||
assert response.content == b""
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aread():
|
||||
response = httpx.Response(
|
||||
@ -249,6 +270,22 @@ async def test_aread():
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_aread():
|
||||
response = httpx.Response(200)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == ""
|
||||
assert response.encoding is None
|
||||
assert response.is_closed
|
||||
|
||||
content = await response.aread()
|
||||
|
||||
assert content == b""
|
||||
assert response.content == b""
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
def test_iter_raw():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
@ -261,6 +298,28 @@ def test_iter_raw():
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
def test_iter_raw_on_iterable():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
content=StreamingBody(),
|
||||
)
|
||||
|
||||
raw = b""
|
||||
for part in response.iter_raw():
|
||||
raw += part
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
def test_iter_raw_on_async():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
content=async_streaming_body(),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part for part in response.iter_raw()]
|
||||
|
||||
|
||||
def test_iter_raw_increments_updates_counter():
|
||||
response = httpx.Response(200, content=streaming_body())
|
||||
|
||||
@ -280,6 +339,17 @@ async def test_aiter_raw():
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_raw_on_sync():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
content=streaming_body(),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part async for part in response.aiter_raw()]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_raw_increments_updates_counter():
|
||||
response = httpx.Response(200, content=async_streaming_body())
|
||||
@ -610,3 +680,20 @@ def test_cannot_access_unset_request():
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
response.request
|
||||
|
||||
|
||||
def test_generator_with_transfer_encoding_header():
|
||||
def content():
|
||||
yield b"test 123" # pragma: nocover
|
||||
|
||||
response = httpx.Response(200, content=content())
|
||||
assert response.headers == httpx.Headers({"Transfer-Encoding": "chunked"})
|
||||
|
||||
|
||||
def test_generator_with_content_length_header():
|
||||
def content():
|
||||
yield b"test 123" # pragma: nocover
|
||||
|
||||
headers = {"Content-Length": "8"}
|
||||
response = httpx.Response(200, content=content(), headers=headers)
|
||||
assert response.headers == httpx.Headers({"Content-Length": "8"})
|
||||
|
||||
@ -1,53 +1,48 @@
|
||||
import io
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
|
||||
from httpx import StreamConsumed
|
||||
from httpx._content_streams import ContentStream, encode_request, encode_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_content():
|
||||
stream = ContentStream()
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {}
|
||||
assert sync_content == b""
|
||||
assert async_content == b""
|
||||
from httpx._content_streams import encode_request, encode_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_content():
|
||||
stream = encode_request()
|
||||
headers, stream = encode_request()
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {}
|
||||
assert headers == {}
|
||||
assert sync_content == b""
|
||||
assert async_content == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bytes_content():
|
||||
stream = encode_request(content=b"Hello, world!")
|
||||
headers, stream = encode_request(content=b"Hello, world!")
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {"Content-Length": "13"}
|
||||
assert headers == {"Content-Length": "13"}
|
||||
assert sync_content == b"Hello, world!"
|
||||
assert async_content == b"Hello, world!"
|
||||
|
||||
# Support 'data' for compat with requests.
|
||||
stream = encode_request(data=b"Hello, world!") # type: ignore
|
||||
headers, stream = encode_request(data=b"Hello, world!") # type: ignore
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {"Content-Length": "13"}
|
||||
assert headers == {"Content-Length": "13"}
|
||||
assert sync_content == b"Hello, world!"
|
||||
assert async_content == b"Hello, world!"
|
||||
|
||||
@ -58,25 +53,26 @@ async def test_iterator_content():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode_request(content=hello_world())
|
||||
headers, stream = encode_request(content=hello_world())
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert not isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
content = b"".join([part for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert headers == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part async for part in stream]
|
||||
|
||||
with pytest.raises(StreamConsumed):
|
||||
[part for part in stream]
|
||||
|
||||
# Support 'data' for compat with requests.
|
||||
stream = encode_request(data=hello_world()) # type: ignore
|
||||
headers, stream = encode_request(data=hello_world()) # type: ignore
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert not isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
content = b"".join([part for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert headers == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@ -86,36 +82,39 @@ async def test_aiterator_content():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode_request(content=hello_world())
|
||||
headers, stream = encode_request(content=hello_world())
|
||||
assert not isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
content = b"".join([part async for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert headers == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part for part in stream]
|
||||
|
||||
with pytest.raises(StreamConsumed):
|
||||
[part async for part in stream]
|
||||
|
||||
# Support 'data' for compat with requests.
|
||||
stream = encode_request(data=hello_world()) # type: ignore
|
||||
headers, stream = encode_request(data=hello_world()) # type: ignore
|
||||
assert not isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
content = b"".join([part async for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert headers == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_content():
|
||||
stream = encode_request(json={"Hello": "world!"})
|
||||
headers, stream = encode_request(json={"Hello": "world!"})
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {
|
||||
assert headers == {
|
||||
"Content-Length": "19",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
@ -125,12 +124,14 @@ async def test_json_content():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urlencoded_content():
|
||||
stream = encode_request(data={"Hello": "world!"})
|
||||
headers, stream = encode_request(data={"Hello": "world!"})
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {
|
||||
assert headers == {
|
||||
"Content-Length": "14",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
@ -141,12 +142,14 @@ async def test_urlencoded_content():
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_files_content():
|
||||
files = {"file": io.BytesIO(b"<file content>")}
|
||||
stream = encode_request(files=files, boundary=b"+++")
|
||||
headers, stream = encode_request(files=files, boundary=b"+++")
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {
|
||||
assert headers == {
|
||||
"Content-Length": "138",
|
||||
"Content-Type": "multipart/form-data; boundary=+++",
|
||||
}
|
||||
@ -176,12 +179,14 @@ async def test_multipart_files_content():
|
||||
async def test_multipart_data_and_files_content():
|
||||
data = {"message": "Hello, world!"}
|
||||
files = {"file": io.BytesIO(b"<file content>")}
|
||||
stream = encode_request(data=data, files=files, boundary=b"+++")
|
||||
headers, stream = encode_request(data=data, files=files, boundary=b"+++")
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {
|
||||
assert headers == {
|
||||
"Content-Length": "210",
|
||||
"Content-Type": "multipart/form-data; boundary=+++",
|
||||
}
|
||||
@ -217,12 +222,14 @@ async def test_multipart_data_and_files_content():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_request():
|
||||
stream = encode_request(data={}, files={})
|
||||
headers, stream = encode_request(data={}, files={})
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {}
|
||||
assert headers == {}
|
||||
assert sync_content == b""
|
||||
assert async_content == b""
|
||||
|
||||
@ -238,12 +245,14 @@ async def test_multipart_multiple_files_single_input_content():
|
||||
("file", io.BytesIO(b"<file content 1>")),
|
||||
("file", io.BytesIO(b"<file content 2>")),
|
||||
]
|
||||
stream = encode_request(files=files, boundary=b"+++")
|
||||
headers, stream = encode_request(files=files, boundary=b"+++")
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {
|
||||
assert headers == {
|
||||
"Content-Length": "271",
|
||||
"Content-Type": "multipart/form-data; boundary=+++",
|
||||
}
|
||||
@ -281,24 +290,28 @@ async def test_multipart_multiple_files_single_input_content():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_empty_content():
|
||||
stream = encode_response()
|
||||
headers, stream = encode_response()
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {}
|
||||
assert headers == {}
|
||||
assert sync_content == b""
|
||||
assert async_content == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_bytes_content():
|
||||
stream = encode_response(content=b"Hello, world!")
|
||||
headers, stream = encode_response(content=b"Hello, world!")
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {"Content-Length": "13"}
|
||||
assert headers == {"Content-Length": "13"}
|
||||
assert sync_content == b"Hello, world!"
|
||||
assert async_content == b"Hello, world!"
|
||||
|
||||
@ -309,16 +322,15 @@ async def test_response_iterator_content():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode_response(content=hello_world())
|
||||
headers, stream = encode_response(content=hello_world())
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
assert not isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
content = b"".join([part for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert headers == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part async for part in stream]
|
||||
|
||||
with pytest.raises(StreamConsumed):
|
||||
[part for part in stream]
|
||||
|
||||
@ -329,16 +341,15 @@ async def test_response_aiterator_content():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode_response(content=hello_world())
|
||||
headers, stream = encode_response(content=hello_world())
|
||||
assert not isinstance(stream, typing.Iterable)
|
||||
assert isinstance(stream, typing.AsyncIterable)
|
||||
|
||||
content = b"".join([part async for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert headers == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part for part in stream]
|
||||
|
||||
with pytest.raises(StreamConsumed):
|
||||
[part async for part in stream]
|
||||
|
||||
|
||||
@ -110,9 +110,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = os.urandom(16).hex()
|
||||
|
||||
stream = encode_request(data=data, files=files)
|
||||
headers, stream = encode_request(data=data, files=files)
|
||||
assert isinstance(stream, MultipartStream)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
content = (
|
||||
@ -128,7 +127,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
|
||||
"--{0}--\r\n"
|
||||
"".format(boundary).encode("ascii")
|
||||
)
|
||||
assert stream.get_headers()["Content-Length"] == str(len(content))
|
||||
assert headers["Content-Length"] == str(len(content))
|
||||
assert b"".join(stream) == content
|
||||
|
||||
|
||||
@ -137,9 +136,8 @@ def test_multipart_encode_files_allows_filenames_as_none() -> None:
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = os.urandom(16).hex()
|
||||
|
||||
stream = encode_request(data={}, files=files)
|
||||
headers, stream = encode_request(data={}, files=files)
|
||||
assert isinstance(stream, MultipartStream)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
assert b"".join(stream) == (
|
||||
@ -164,9 +162,8 @@ def test_multipart_encode_files_guesses_correct_content_type(
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = os.urandom(16).hex()
|
||||
|
||||
stream = encode_request(data={}, files=files)
|
||||
headers, stream = encode_request(data={}, files=files)
|
||||
assert isinstance(stream, MultipartStream)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
assert b"".join(stream) == (
|
||||
@ -188,9 +185,8 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = os.urandom(16).hex()
|
||||
|
||||
stream = encode_request(data={}, files=files)
|
||||
headers, stream = encode_request(data={}, files=files)
|
||||
assert isinstance(stream, MultipartStream)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
content = (
|
||||
@ -200,7 +196,7 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
|
||||
"--{0}--\r\n"
|
||||
"".format(boundary, output).encode("ascii")
|
||||
)
|
||||
assert stream.get_headers()["Content-Length"] == str(len(content))
|
||||
assert headers["Content-Length"] == str(len(content))
|
||||
assert b"".join(stream) == content
|
||||
|
||||
|
||||
@ -214,9 +210,6 @@ def test_multipart_encode_non_seekable_filelike() -> None:
|
||||
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
|
||||
self._iterator = iterator
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return False
|
||||
|
||||
def read(self, *args: typing.Any) -> bytes:
|
||||
return b"".join(self._iterator)
|
||||
|
||||
@ -226,8 +219,8 @@ def test_multipart_encode_non_seekable_filelike() -> None:
|
||||
|
||||
fileobj: typing.Any = IteratorIO(data())
|
||||
files = {"file": fileobj}
|
||||
stream = encode_request(files=files, boundary=b"+++")
|
||||
assert not stream.can_replay()
|
||||
headers, stream = encode_request(files=files, boundary=b"+++")
|
||||
assert isinstance(stream, typing.Iterable)
|
||||
|
||||
content = (
|
||||
b"--+++\r\n"
|
||||
@ -237,7 +230,7 @@ def test_multipart_encode_non_seekable_filelike() -> None:
|
||||
b"HelloWorld\r\n"
|
||||
b"--+++--\r\n"
|
||||
)
|
||||
assert stream.get_headers() == {
|
||||
assert headers == {
|
||||
"Content-Type": "multipart/form-data; boundary=+++",
|
||||
"Content-Length": str(len(content)),
|
||||
}
|
||||
|
||||
@ -36,22 +36,11 @@ class MockTransport(httpcore.SyncHTTPTransport):
|
||||
stream: httpcore.SyncByteStream = None,
|
||||
timeout: Mapping[str, Optional[float]] = None,
|
||||
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
|
||||
request_headers = httpx.Headers(headers)
|
||||
content = (
|
||||
(item for item in stream)
|
||||
if stream
|
||||
and (
|
||||
"Content-Length" in request_headers
|
||||
or "Transfer-Encoding" in request_headers
|
||||
)
|
||||
else None
|
||||
)
|
||||
|
||||
request = httpx.Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=request_headers,
|
||||
content=content,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
)
|
||||
request.read()
|
||||
response = self.handler(request)
|
||||
@ -60,13 +49,13 @@ class MockTransport(httpcore.SyncHTTPTransport):
|
||||
response.status_code,
|
||||
response.reason_phrase.encode("ascii"),
|
||||
response.headers.raw,
|
||||
response._raw_stream,
|
||||
response.stream,
|
||||
)
|
||||
|
||||
|
||||
class AsyncMockTransport(httpcore.AsyncHTTPTransport):
|
||||
def __init__(self, handler: Callable) -> None:
|
||||
self.impl = MockTransport(handler)
|
||||
self.handler = handler
|
||||
|
||||
async def request(
|
||||
self,
|
||||
@ -76,28 +65,18 @@ class AsyncMockTransport(httpcore.AsyncHTTPTransport):
|
||||
stream: httpcore.AsyncByteStream = None,
|
||||
timeout: Mapping[str, Optional[float]] = None,
|
||||
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
|
||||
content = (
|
||||
httpcore.PlainByteStream(b"".join([part async for part in stream]))
|
||||
if stream
|
||||
else httpcore.PlainByteStream(b"")
|
||||
request = httpx.Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
(
|
||||
http_version,
|
||||
status_code,
|
||||
reason_phrase,
|
||||
headers,
|
||||
response_stream,
|
||||
) = self.impl.request(
|
||||
method, url, headers=headers, stream=content, timeout=timeout
|
||||
)
|
||||
|
||||
content = httpcore.PlainByteStream(b"".join([part for part in response_stream]))
|
||||
|
||||
await request.aread()
|
||||
response = self.handler(request)
|
||||
return (
|
||||
http_version,
|
||||
status_code,
|
||||
reason_phrase,
|
||||
headers,
|
||||
content,
|
||||
(response.http_version or "HTTP/1.1").encode("ascii"),
|
||||
response.status_code,
|
||||
response.reason_phrase.encode("ascii"),
|
||||
response.headers.raw,
|
||||
response.stream,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user