Support Response(text=...), Response(html=...), Response(json=...) (#1297)

* Refactor content_streams internally

* Tidy up multipart

* Use ByteStream annotation internally

* Support Response(text=...), Response(html=...), Response(json=...)

* Add tests for Response(text=..., html=..., json=...)
This commit is contained in:
Tom Christie 2020-09-21 11:19:19 +01:00 committed by GitHub
parent 8ee08afe96
commit f3c29416f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 115 additions and 42 deletions

View File

@ -1,6 +1,15 @@
import inspect
import typing
from json import dumps as json_dumps
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Dict,
Iterable,
Iterator,
Tuple,
Union,
)
from urllib.parse import urlencode
from ._exceptions import StreamConsumed
@ -22,10 +31,10 @@ class PlainByteStream:
def __init__(self, body: bytes) -> None:
self._body = body
def __iter__(self) -> typing.Iterator[bytes]:
def __iter__(self) -> Iterator[bytes]:
yield self._body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncIterator[bytes]:
yield self._body
@ -34,11 +43,11 @@ class GeneratorStream:
Request content encoded as plain bytes, using an byte generator.
"""
def __init__(self, generator: typing.Iterable[bytes]) -> None:
def __init__(self, generator: Iterable[bytes]) -> None:
self._generator = generator
self._is_stream_consumed = False
def __iter__(self) -> typing.Iterator[bytes]:
def __iter__(self) -> Iterator[bytes]:
if self._is_stream_consumed:
raise StreamConsumed()
@ -52,11 +61,11 @@ class AsyncGeneratorStream:
Request content encoded as plain bytes, using an async byte iterator.
"""
def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
def __init__(self, agenerator: AsyncIterable[bytes]) -> None:
self._agenerator = agenerator
self._is_stream_consumed = False
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncIterator[bytes]:
if self._is_stream_consumed:
raise StreamConsumed()
@ -66,8 +75,8 @@ class AsyncGeneratorStream:
def encode_content(
content: typing.Union[str, bytes, ByteStream]
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
content: Union[str, bytes, ByteStream]
) -> Tuple[Dict[str, str], ByteStream]:
if isinstance(content, (str, bytes)):
body = content.encode("utf-8") if isinstance(content, str) else content
content_length = str(len(body))
@ -75,7 +84,7 @@ def encode_content(
stream = PlainByteStream(body)
return headers, stream
elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
elif isinstance(content, (Iterable, AsyncIterable)):
headers = {"Transfer-Encoding": "chunked"}
# Generators should be wrapped in GeneratorStream/AsyncGeneratorStream
@ -96,7 +105,7 @@ def encode_content(
def encode_urlencoded_data(
data: dict,
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], ByteStream]:
body = urlencode(data, doseq=True).encode("utf-8")
content_length = str(len(body))
content_type = "application/x-www-form-urlencoded"
@ -106,13 +115,29 @@ def encode_urlencoded_data(
def encode_multipart_data(
data: dict, files: RequestFiles, boundary: bytes = None
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], ByteStream]:
stream = MultipartStream(data=data, files=files, boundary=boundary)
headers = stream.get_headers()
return headers, stream
def encode_json(json: typing.Any) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
body = text.encode("utf-8")
content_length = str(len(body))
content_type = "text/plain; charset=utf-8"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
body = html.encode("utf-8")
content_length = str(len(body))
content_type = "text/html; charset=utf-8"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, PlainByteStream(body)
def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
body = json_dumps(json).encode("utf-8")
content_length = str(len(body))
content_type = "application/json"
@ -124,9 +149,9 @@ def encode_request(
content: RequestContent = None,
data: RequestData = None,
files: RequestFiles = None,
json: typing.Any = None,
json: Any = None,
boundary: bytes = None,
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
) -> Tuple[Dict[str, str], ByteStream]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
returning a two-tuple of (<headers>, <stream>).
@ -155,12 +180,21 @@ def encode_request(
def encode_response(
content: ResponseContent = None,
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
text: str = None,
html: str = None,
json: Any = None,
) -> Tuple[Dict[str, str], ByteStream]:
"""
Handles encoding the given `content`, returning a two-tuple of
(<headers>, <stream>).
"""
if content is not None:
return encode_content(content)
elif text is not None:
return encode_text(text)
elif html is not None:
return encode_html(html)
elif json is not None:
return encode_json(json)
return {}, PlainByteStream(b"")

View File

@ -704,11 +704,14 @@ class Response:
self,
status_code: int,
*,
request: Request = None,
http_version: str = None,
headers: HeaderTypes = None,
content: ResponseContent = None,
text: str = None,
html: str = None,
json: typing.Any = None,
stream: ByteStream = None,
http_version: str = None,
request: Request = None,
history: typing.List["Response"] = None,
on_close: typing.Callable = None,
):
@ -740,7 +743,7 @@ class Response:
# from the transport API.
self.stream = stream
else:
headers, stream = encode_response(content)
headers, stream = encode_response(content, text, html, json)
self._prepare(headers)
self.stream = stream
if content is None or isinstance(content, bytes):

View File

@ -5,7 +5,6 @@ Unit tests for auth classes also exist in tests/test_auth.py
"""
import asyncio
import hashlib
import json
import os
import threading
import typing
@ -27,8 +26,7 @@ class App:
def __call__(self, request: httpx.Request) -> httpx.Response:
headers = {"www-authenticate": self.auth_header} if self.auth_header else {}
data = {"auth": request.headers.get("Authorization")}
content = json.dumps(data).encode("utf-8")
return httpx.Response(self.status_code, headers=headers, content=content)
return httpx.Response(self.status_code, headers=headers, json=data)
class DigestApp:
@ -50,8 +48,7 @@ class DigestApp:
return self.challenge_send(request)
data = {"auth": request.headers.get("Authorization")}
content = json.dumps(data).encode("utf-8")
return httpx.Response(200, content=content)
return httpx.Response(200, json=data)
def challenge_send(self, request: httpx.Request) -> httpx.Response:
self._response_count += 1

View File

@ -1,4 +1,3 @@
import json
from http.cookiejar import Cookie, CookieJar
import httpx
@ -8,8 +7,7 @@ from tests.utils import MockTransport
def get_and_set_cookies(request: httpx.Request) -> httpx.Response:
if request.url.path == "/echo_cookies":
data = {"cookies": request.headers.get("cookie")}
content = json.dumps(data).encode("utf-8")
return httpx.Response(200, content=content)
return httpx.Response(200, json=data)
elif request.url.path == "/set_cookie":
return httpx.Response(200, headers={"set-cookie": "example-name=example-value"})
else:

View File

@ -1,7 +1,5 @@
#!/usr/bin/env python3
import json
import pytest
import httpx
@ -10,8 +8,7 @@ from tests.utils import MockTransport
def echo_headers(request: httpx.Request) -> httpx.Response:
data = {"headers": dict(request.headers)}
content = json.dumps(data).encode("utf-8")
return httpx.Response(200, content=content)
return httpx.Response(200, json=data)
def test_client_header():

View File

@ -3,7 +3,7 @@ from tests.utils import MockTransport
def hello_world(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, content=b"Hello, world")
return httpx.Response(200, text="Hello, world")
def test_client_queryparams():

View File

@ -1,5 +1,3 @@
import json
import httpcore
import pytest
@ -78,8 +76,11 @@ def redirects(request: httpx.Request) -> httpx.Response:
elif request.url.path == "/cross_domain_target":
status_code = httpx.codes.OK
content = json.dumps({"headers": dict(request.headers)}).encode("utf-8")
return httpx.Response(status_code, content=content)
data = {
"body": request.content.decode("ascii"),
"headers": dict(request.headers),
}
return httpx.Response(status_code, json=data)
elif request.url.path == "/redirect_body":
status_code = httpx.codes.PERMANENT_REDIRECT
@ -92,10 +93,11 @@ def redirects(request: httpx.Request) -> httpx.Response:
return httpx.Response(status_code, headers=headers)
elif request.url.path == "/redirect_body_target":
content = json.dumps(
{"body": request.content.decode("ascii"), "headers": dict(request.headers)}
).encode("utf-8")
return httpx.Response(200, content=content)
data = {
"body": request.content.decode("ascii"),
"headers": dict(request.headers),
}
return httpx.Response(200, json=data)
elif request.url.path == "/cross_subdomain":
if request.headers["Host"] != "www.example.org":
@ -103,7 +105,7 @@ def redirects(request: httpx.Request) -> httpx.Response:
headers = {"location": "https://www.example.org/cross_subdomain"}
return httpx.Response(status_code, headers=headers)
else:
return httpx.Response(200, content=b"Hello, world!")
return httpx.Response(200, text="Hello, world!")
elif request.url.path == "/redirect_custom_scheme":
status_code = httpx.codes.MOVED_PERMANENTLY
@ -113,7 +115,7 @@ def redirects(request: httpx.Request) -> httpx.Response:
if request.method == "HEAD":
return httpx.Response(200)
return httpx.Response(200, content=b"Hello, world!")
return httpx.Response(200, html="<html><body>Hello, world!</body></html>")
def test_no_redirect():

View File

@ -38,6 +38,48 @@ def test_response():
assert not response.is_error
def test_response_text():
response = httpx.Response(200, text="Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"
assert response.text == "Hello, world!"
assert response.headers == httpx.Headers(
{
"Content-Length": "13",
"Content-Type": "text/plain; charset=utf-8",
}
)
def test_response_html():
response = httpx.Response(200, html="<html><body>Hello, world!</html></body>")
assert response.status_code == 200
assert response.reason_phrase == "OK"
assert response.text == "<html><body>Hello, world!</html></body>"
assert response.headers == httpx.Headers(
{
"Content-Length": "39",
"Content-Type": "text/html; charset=utf-8",
}
)
def test_response_json():
response = httpx.Response(200, json={"hello": "world"})
assert response.status_code == 200
assert response.reason_phrase == "OK"
assert response.json() == {"hello": "world"}
assert response.headers == httpx.Headers(
{
"Content-Length": "18",
"Content-Type": "application/json",
}
)
def test_raise_for_status():
request = httpx.Request("GET", "https://example.org")