allow setting an explicit multipart boundary via headers (#2278)
This commit is contained in:
parent
2434e650ee
commit
1526048c94
@ -150,7 +150,7 @@ def encode_urlencoded_data(
|
||||
|
||||
|
||||
def encode_multipart_data(
|
||||
data: dict, files: RequestFiles, boundary: Optional[bytes] = None
|
||||
data: dict, files: RequestFiles, boundary: Optional[bytes]
|
||||
) -> Tuple[Dict[str, str], MultipartStream]:
|
||||
multipart = MultipartStream(data=data, files=files, boundary=boundary)
|
||||
headers = multipart.get_headers()
|
||||
|
||||
@ -27,6 +27,7 @@ from ._exceptions import (
|
||||
StreamConsumed,
|
||||
request_context,
|
||||
)
|
||||
from ._multipart import get_multipart_boundary_from_content_type
|
||||
from ._status_codes import codes
|
||||
from ._types import (
|
||||
AsyncByteStream,
|
||||
@ -332,7 +333,18 @@ class Request:
|
||||
Cookies(cookies).set_cookie_header(self)
|
||||
|
||||
if stream is None:
|
||||
headers, stream = encode_request(content, data, files, json)
|
||||
content_type: typing.Optional[str] = self.headers.get("content-type")
|
||||
headers, stream = encode_request(
|
||||
content=content,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
boundary=get_multipart_boundary_from_content_type(
|
||||
content_type=content_type.encode(self.headers.encoding)
|
||||
if content_type
|
||||
else None
|
||||
),
|
||||
)
|
||||
self._prepare(headers)
|
||||
self.stream = stream
|
||||
# Load the request body, except for streaming content.
|
||||
|
||||
@ -20,6 +20,20 @@ from ._utils import (
|
||||
)
|
||||
|
||||
|
||||
def get_multipart_boundary_from_content_type(
|
||||
content_type: typing.Optional[bytes],
|
||||
) -> typing.Optional[bytes]:
|
||||
if not content_type or not content_type.startswith(b"multipart/form-data"):
|
||||
return None
|
||||
# parse boundary according to
|
||||
# https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
|
||||
if b";" in content_type:
|
||||
for section in content_type.split(b";"):
|
||||
if section.strip().lower().startswith(b"boundary="):
|
||||
return section.strip()[len(b"boundary=") :].strip(b'"')
|
||||
return None
|
||||
|
||||
|
||||
class DataField:
|
||||
"""
|
||||
A single form field item, within a multipart form field.
|
||||
|
||||
@ -42,6 +42,58 @@ def test_multipart(value, output):
|
||||
assert multipart["file"] == [b"<file content>"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"header",
|
||||
[
|
||||
"multipart/form-data; boundary=+++; charset=utf-8",
|
||||
"multipart/form-data; charset=utf-8; boundary=+++",
|
||||
"multipart/form-data; boundary=+++",
|
||||
"multipart/form-data; boundary=+++ ;",
|
||||
'multipart/form-data; boundary="+++"; charset=utf-8',
|
||||
'multipart/form-data; charset=utf-8; boundary="+++"',
|
||||
'multipart/form-data; boundary="+++"',
|
||||
'multipart/form-data; boundary="+++" ;',
|
||||
],
|
||||
)
|
||||
def test_multipart_explicit_boundary(header: str) -> None:
|
||||
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
|
||||
|
||||
files = {"file": io.BytesIO(b"<file content>")}
|
||||
headers = {"content-type": header}
|
||||
response = client.post("http://127.0.0.1:8000/", files=files, headers=headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# We're using the cgi module to verify the behavior here, which is a
|
||||
# bit grungy, but sufficient just for our testing purposes.
|
||||
assert response.request.headers["Content-Type"] == header
|
||||
content_length = response.request.headers["Content-Length"]
|
||||
pdict: dict = {
|
||||
"boundary": b"+++",
|
||||
"CONTENT-LENGTH": content_length,
|
||||
}
|
||||
multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)
|
||||
|
||||
assert multipart["file"] == [b"<file content>"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"header",
|
||||
[
|
||||
"multipart/form-data; charset=utf-8",
|
||||
"multipart/form-data; charset=utf-8; ",
|
||||
],
|
||||
)
|
||||
def test_multipart_header_without_boundary(header: str) -> None:
|
||||
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
|
||||
|
||||
files = {"file": io.BytesIO(b"<file content>")}
|
||||
headers = {"content-type": header}
|
||||
response = client.post("http://127.0.0.1:8000/", files=files, headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.request.headers["Content-Type"] == header
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))
|
||||
def test_multipart_invalid_key(key):
|
||||
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user