Allow custom headers in multipart/form-data requests (#1936)

* feat: allow passing multipart headers

* Add test for including content-type in headers

* lint

* override content_type with headers

* compare tuples based on length

* incorporate suggestion

* remove .title() on headers
This commit is contained in:
Adrian Garcia Badaracco 2022-01-13 00:49:14 -08:00 committed by GitHub
parent 3eaf69a772
commit 0f1ff50a1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 8 deletions

View File

@ -78,23 +78,41 @@ class FileField:
fileobj: FileContent
headers: typing.Dict[str, str] = {}
content_type: typing.Optional[str] = None
# This large tuple based API largely mirror's requests' API
# It would be good to think of better APIs for this that we could include in httpx 2.0
# since variable length tuples (especially of 4 elements) are quite unwieldly
if isinstance(value, tuple):
try:
filename, fileobj, content_type = value # type: ignore
except ValueError:
if len(value) == 2:
# neither the 3rd parameter (content_type) nor the 4th (headers) was included
filename, fileobj = value # type: ignore
content_type = guess_content_type(filename)
elif len(value) == 3:
filename, fileobj, content_type = value # type: ignore
else:
# all 4 parameters included
filename, fileobj, content_type, headers = value # type: ignore
else:
filename = Path(str(getattr(value, "name", "upload"))).name
fileobj = value
if content_type is None:
content_type = guess_content_type(filename)
has_content_type_header = any("content-type" in key.lower() for key in headers)
if content_type is not None and not has_content_type_header:
# note that unlike requests, we ignore the content_type
# provided in the 3rd tuple element if it is also included in the headers
# requests does the opposite (it overwrites the header with the 3rd tuple element)
headers["Content-Type"] = content_type
if isinstance(fileobj, (str, io.StringIO)):
raise TypeError(f"Expected bytes or bytes-like object got: {type(fileobj)}")
self.filename = filename
self.file = fileobj
self.content_type = content_type
self.headers = headers
self._consumed = False
def get_length(self) -> int:
@ -122,9 +140,9 @@ class FileField:
if self.filename:
filename = format_form_param("filename", self.filename)
parts.extend([b"; ", filename])
if self.content_type is not None:
content_type = self.content_type.encode()
parts.extend([b"\r\nContent-Type: ", content_type])
for header_name, header_value in self.headers.items():
key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
parts.extend([key, val])
parts.append(b"\r\n\r\n")
self._headers = b"".join(parts)

View File

@ -89,6 +89,8 @@ FileTypes = Union[
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]

View File

@ -94,6 +94,58 @@ def test_multipart_file_tuple():
assert multipart["file"] == [b"<file content>"]
@pytest.mark.parametrize("content_type", [None, "text/plain"])
def test_multipart_file_tuple_headers(content_type: typing.Optional[str]):
file_name = "test.txt"
expected_content_type = "text/plain"
headers = {"Expires": "0"}
files = {"file": (file_name, io.BytesIO(b"<file content>"), content_type, headers)}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()
headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, typing.Iterable)
content = (
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
f'filename="{file_name}"\r\nExpires: 0\r\nContent-Type: '
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
"".encode("ascii")
)
assert headers == {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Content-Length": str(len(content)),
}
assert content == b"".join(stream)
def test_multipart_headers_include_content_type() -> None:
"""Content-Type from 4th tuple parameter (headers) should override the 3rd parameter (content_type)"""
file_name = "test.txt"
expected_content_type = "image/png"
headers = {"Content-Type": "image/png"}
files = {"file": (file_name, io.BytesIO(b"<file content>"), "text_plain", headers)}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()
headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, typing.Iterable)
content = (
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
f'filename="{file_name}"\r\nContent-Type: '
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
"".encode("ascii")
)
assert headers == {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Content-Length": str(len(content)),
}
assert content == b"".join(stream)
def test_multipart_encode(tmp_path: typing.Any) -> None:
path = str(tmp_path / "name.txt")
with open(path, "wb") as f: