Fix bytes support in multipart uploads (#974)

This commit is contained in:
Florimond Manca 2020-05-21 16:25:31 +02:00 committed by GitHub
parent a58be59adb
commit ab9ace2749
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions

View File

@ -247,7 +247,7 @@ class MultipartStream(ContentStream):
def get_length(self) -> int:
headers = self.render_headers()
if isinstance(self.file, str):
if isinstance(self.file, (str, bytes)):
return len(headers) + len(self.file)
# Let's do our best not to read `file` into memory.
@ -279,7 +279,7 @@ class MultipartStream(ContentStream):
return self._headers
def render_data(self) -> typing.Iterator[bytes]:
if isinstance(self.file, str):
if isinstance(self.file, (str, bytes)):
yield to_bytes(self.file)
return
@ -297,7 +297,7 @@ class MultipartStream(ContentStream):
self.file.seek(0)
def can_replay(self) -> bool:
return True if isinstance(self.file, str) else self.file.seekable()
return True if isinstance(self.file, (str, bytes)) else self.file.seekable()
def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()

View File

@ -181,8 +181,14 @@ def test_multipart_encode_files_guesses_correct_content_type(
)
def test_multipart_encode_files_allows_str_content() -> None:
files = {"file": ("test.txt", "<string content>", "text/plain")}
@pytest.mark.parametrize(
"value, output",
((b"<bytes content>", "<bytes content>"), ("<string content>", "<string content>")),
)
def test_multipart_encode_files_allows_bytes_or_str_content(
value: typing.Union[str, bytes], output: str
) -> None:
files = {"file": ("test.txt", value, "text/plain")}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
@ -193,9 +199,9 @@ def test_multipart_encode_files_allows_str_content() -> None:
content = (
'--{0}\r\nContent-Disposition: form-data; name="file"; '
'filename="test.txt"\r\n'
"Content-Type: text/plain\r\n\r\n<string content>\r\n"
"Content-Type: text/plain\r\n\r\n{1}\r\n"
"--{0}--\r\n"
"".format(boundary).encode("ascii")
"".format(boundary, output).encode("ascii")
)
assert stream.get_headers()["Content-Length"] == str(len(content))
assert b"".join(stream) == content