Support Response(text=...), Response(html=...), Response(json=...) (#1297)
* Refactor content_streams internally * Tidy up multipart * Use ByteStream annotation internally * Support Response(text=...), Response(html=...), Response(json=...) * Add tests for Response(text=..., html=..., json=...)
This commit is contained in:
parent
8ee08afe96
commit
f3c29416f1
@ -1,6 +1,15 @@
|
||||
import inspect
|
||||
import typing
|
||||
from json import dumps as json_dumps
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from ._exceptions import StreamConsumed
|
||||
@ -22,10 +31,10 @@ class PlainByteStream:
|
||||
def __init__(self, body: bytes) -> None:
|
||||
self._body = body
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
yield self._body
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
yield self._body
|
||||
|
||||
|
||||
@ -34,11 +43,11 @@ class GeneratorStream:
|
||||
Request content encoded as plain bytes, using an byte generator.
|
||||
"""
|
||||
|
||||
def __init__(self, generator: typing.Iterable[bytes]) -> None:
|
||||
def __init__(self, generator: Iterable[bytes]) -> None:
|
||||
self._generator = generator
|
||||
self._is_stream_consumed = False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
if self._is_stream_consumed:
|
||||
raise StreamConsumed()
|
||||
|
||||
@ -52,11 +61,11 @@ class AsyncGeneratorStream:
|
||||
Request content encoded as plain bytes, using an async byte iterator.
|
||||
"""
|
||||
|
||||
def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
|
||||
def __init__(self, agenerator: AsyncIterable[bytes]) -> None:
|
||||
self._agenerator = agenerator
|
||||
self._is_stream_consumed = False
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
if self._is_stream_consumed:
|
||||
raise StreamConsumed()
|
||||
|
||||
@ -66,8 +75,8 @@ class AsyncGeneratorStream:
|
||||
|
||||
|
||||
def encode_content(
|
||||
content: typing.Union[str, bytes, ByteStream]
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
content: Union[str, bytes, ByteStream]
|
||||
) -> Tuple[Dict[str, str], ByteStream]:
|
||||
if isinstance(content, (str, bytes)):
|
||||
body = content.encode("utf-8") if isinstance(content, str) else content
|
||||
content_length = str(len(body))
|
||||
@ -75,7 +84,7 @@ def encode_content(
|
||||
stream = PlainByteStream(body)
|
||||
return headers, stream
|
||||
|
||||
elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
|
||||
elif isinstance(content, (Iterable, AsyncIterable)):
|
||||
headers = {"Transfer-Encoding": "chunked"}
|
||||
|
||||
# Generators should be wrapped in GeneratorStream/AsyncGeneratorStream
|
||||
@ -96,7 +105,7 @@ def encode_content(
|
||||
|
||||
def encode_urlencoded_data(
|
||||
data: dict,
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
) -> Tuple[Dict[str, str], ByteStream]:
|
||||
body = urlencode(data, doseq=True).encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "application/x-www-form-urlencoded"
|
||||
@ -106,13 +115,29 @@ def encode_urlencoded_data(
|
||||
|
||||
def encode_multipart_data(
|
||||
data: dict, files: RequestFiles, boundary: bytes = None
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
) -> Tuple[Dict[str, str], ByteStream]:
|
||||
stream = MultipartStream(data=data, files=files, boundary=boundary)
|
||||
headers = stream.get_headers()
|
||||
return headers, stream
|
||||
|
||||
|
||||
def encode_json(json: typing.Any) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
|
||||
body = text.encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "text/plain; charset=utf-8"
|
||||
headers = {"Content-Length": content_length, "Content-Type": content_type}
|
||||
return headers, PlainByteStream(body)
|
||||
|
||||
|
||||
def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
|
||||
body = html.encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "text/html; charset=utf-8"
|
||||
headers = {"Content-Length": content_length, "Content-Type": content_type}
|
||||
return headers, PlainByteStream(body)
|
||||
|
||||
|
||||
def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
|
||||
body = json_dumps(json).encode("utf-8")
|
||||
content_length = str(len(body))
|
||||
content_type = "application/json"
|
||||
@ -124,9 +149,9 @@ def encode_request(
|
||||
content: RequestContent = None,
|
||||
data: RequestData = None,
|
||||
files: RequestFiles = None,
|
||||
json: typing.Any = None,
|
||||
json: Any = None,
|
||||
boundary: bytes = None,
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
) -> Tuple[Dict[str, str], ByteStream]:
|
||||
"""
|
||||
Handles encoding the given `content`, `data`, `files`, and `json`,
|
||||
returning a two-tuple of (<headers>, <stream>).
|
||||
@ -155,12 +180,21 @@ def encode_request(
|
||||
|
||||
def encode_response(
|
||||
content: ResponseContent = None,
|
||||
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
|
||||
text: str = None,
|
||||
html: str = None,
|
||||
json: Any = None,
|
||||
) -> Tuple[Dict[str, str], ByteStream]:
|
||||
"""
|
||||
Handles encoding the given `content`, returning a two-tuple of
|
||||
(<headers>, <stream>).
|
||||
"""
|
||||
if content is not None:
|
||||
return encode_content(content)
|
||||
elif text is not None:
|
||||
return encode_text(text)
|
||||
elif html is not None:
|
||||
return encode_html(html)
|
||||
elif json is not None:
|
||||
return encode_json(json)
|
||||
|
||||
return {}, PlainByteStream(b"")
|
||||
|
||||
@ -704,11 +704,14 @@ class Response:
|
||||
self,
|
||||
status_code: int,
|
||||
*,
|
||||
request: Request = None,
|
||||
http_version: str = None,
|
||||
headers: HeaderTypes = None,
|
||||
content: ResponseContent = None,
|
||||
text: str = None,
|
||||
html: str = None,
|
||||
json: typing.Any = None,
|
||||
stream: ByteStream = None,
|
||||
http_version: str = None,
|
||||
request: Request = None,
|
||||
history: typing.List["Response"] = None,
|
||||
on_close: typing.Callable = None,
|
||||
):
|
||||
@ -740,7 +743,7 @@ class Response:
|
||||
# from the transport API.
|
||||
self.stream = stream
|
||||
else:
|
||||
headers, stream = encode_response(content)
|
||||
headers, stream = encode_response(content, text, html, json)
|
||||
self._prepare(headers)
|
||||
self.stream = stream
|
||||
if content is None or isinstance(content, bytes):
|
||||
|
||||
@ -5,7 +5,6 @@ Unit tests for auth classes also exist in tests/test_auth.py
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import typing
|
||||
@ -27,8 +26,7 @@ class App:
|
||||
def __call__(self, request: httpx.Request) -> httpx.Response:
|
||||
headers = {"www-authenticate": self.auth_header} if self.auth_header else {}
|
||||
data = {"auth": request.headers.get("Authorization")}
|
||||
content = json.dumps(data).encode("utf-8")
|
||||
return httpx.Response(self.status_code, headers=headers, content=content)
|
||||
return httpx.Response(self.status_code, headers=headers, json=data)
|
||||
|
||||
|
||||
class DigestApp:
|
||||
@ -50,8 +48,7 @@ class DigestApp:
|
||||
return self.challenge_send(request)
|
||||
|
||||
data = {"auth": request.headers.get("Authorization")}
|
||||
content = json.dumps(data).encode("utf-8")
|
||||
return httpx.Response(200, content=content)
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
def challenge_send(self, request: httpx.Request) -> httpx.Response:
|
||||
self._response_count += 1
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from http.cookiejar import Cookie, CookieJar
|
||||
|
||||
import httpx
|
||||
@ -8,8 +7,7 @@ from tests.utils import MockTransport
|
||||
def get_and_set_cookies(request: httpx.Request) -> httpx.Response:
|
||||
if request.url.path == "/echo_cookies":
|
||||
data = {"cookies": request.headers.get("cookie")}
|
||||
content = json.dumps(data).encode("utf-8")
|
||||
return httpx.Response(200, content=content)
|
||||
return httpx.Response(200, json=data)
|
||||
elif request.url.path == "/set_cookie":
|
||||
return httpx.Response(200, headers={"set-cookie": "example-name=example-value"})
|
||||
else:
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
@ -10,8 +8,7 @@ from tests.utils import MockTransport
|
||||
|
||||
def echo_headers(request: httpx.Request) -> httpx.Response:
|
||||
data = {"headers": dict(request.headers)}
|
||||
content = json.dumps(data).encode("utf-8")
|
||||
return httpx.Response(200, content=content)
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
|
||||
def test_client_header():
|
||||
|
||||
@ -3,7 +3,7 @@ from tests.utils import MockTransport
|
||||
|
||||
|
||||
def hello_world(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, content=b"Hello, world")
|
||||
return httpx.Response(200, text="Hello, world")
|
||||
|
||||
|
||||
def test_client_queryparams():
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import json
|
||||
|
||||
import httpcore
|
||||
import pytest
|
||||
|
||||
@ -78,8 +76,11 @@ def redirects(request: httpx.Request) -> httpx.Response:
|
||||
|
||||
elif request.url.path == "/cross_domain_target":
|
||||
status_code = httpx.codes.OK
|
||||
content = json.dumps({"headers": dict(request.headers)}).encode("utf-8")
|
||||
return httpx.Response(status_code, content=content)
|
||||
data = {
|
||||
"body": request.content.decode("ascii"),
|
||||
"headers": dict(request.headers),
|
||||
}
|
||||
return httpx.Response(status_code, json=data)
|
||||
|
||||
elif request.url.path == "/redirect_body":
|
||||
status_code = httpx.codes.PERMANENT_REDIRECT
|
||||
@ -92,10 +93,11 @@ def redirects(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(status_code, headers=headers)
|
||||
|
||||
elif request.url.path == "/redirect_body_target":
|
||||
content = json.dumps(
|
||||
{"body": request.content.decode("ascii"), "headers": dict(request.headers)}
|
||||
).encode("utf-8")
|
||||
return httpx.Response(200, content=content)
|
||||
data = {
|
||||
"body": request.content.decode("ascii"),
|
||||
"headers": dict(request.headers),
|
||||
}
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
elif request.url.path == "/cross_subdomain":
|
||||
if request.headers["Host"] != "www.example.org":
|
||||
@ -103,7 +105,7 @@ def redirects(request: httpx.Request) -> httpx.Response:
|
||||
headers = {"location": "https://www.example.org/cross_subdomain"}
|
||||
return httpx.Response(status_code, headers=headers)
|
||||
else:
|
||||
return httpx.Response(200, content=b"Hello, world!")
|
||||
return httpx.Response(200, text="Hello, world!")
|
||||
|
||||
elif request.url.path == "/redirect_custom_scheme":
|
||||
status_code = httpx.codes.MOVED_PERMANENTLY
|
||||
@ -113,7 +115,7 @@ def redirects(request: httpx.Request) -> httpx.Response:
|
||||
if request.method == "HEAD":
|
||||
return httpx.Response(200)
|
||||
|
||||
return httpx.Response(200, content=b"Hello, world!")
|
||||
return httpx.Response(200, html="<html><body>Hello, world!</body></html>")
|
||||
|
||||
|
||||
def test_no_redirect():
|
||||
|
||||
@ -38,6 +38,48 @@ def test_response():
|
||||
assert not response.is_error
|
||||
|
||||
|
||||
def test_response_text():
|
||||
response = httpx.Response(200, text="Hello, world!")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.reason_phrase == "OK"
|
||||
assert response.text == "Hello, world!"
|
||||
assert response.headers == httpx.Headers(
|
||||
{
|
||||
"Content-Length": "13",
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_response_html():
|
||||
response = httpx.Response(200, html="<html><body>Hello, world!</html></body>")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.reason_phrase == "OK"
|
||||
assert response.text == "<html><body>Hello, world!</html></body>"
|
||||
assert response.headers == httpx.Headers(
|
||||
{
|
||||
"Content-Length": "39",
|
||||
"Content-Type": "text/html; charset=utf-8",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_response_json():
|
||||
response = httpx.Response(200, json={"hello": "world"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.reason_phrase == "OK"
|
||||
assert response.json() == {"hello": "world"}
|
||||
assert response.headers == httpx.Headers(
|
||||
{
|
||||
"Content-Length": "18",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_raise_for_status():
|
||||
request = httpx.Request("GET", "https://example.org")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user