Implement streaming multipart uploads (#857)
* Implement streaming multipart uploads * Tweak seekable check * Don't handle duplicate computation yet * Add memory test for multipart streaming * Lint * 1 pct is enough * Tweak lazy computations, fallback to non-streaming for broken filelikes * Reduce diff size * Drop memory test * Cleanup
This commit is contained in:
parent
af75908b7a
commit
5829ecb648
@ -1,8 +1,6 @@
|
||||
import binascii
|
||||
import mimetypes
|
||||
import os
|
||||
import typing
|
||||
from io import BytesIO
|
||||
from json import dumps as json_dumps
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlencode
|
||||
@ -11,7 +9,12 @@ import httpcore
|
||||
|
||||
from ._exceptions import StreamConsumed
|
||||
from ._types import StrOrBytes
|
||||
from ._utils import format_form_param
|
||||
from ._utils import (
|
||||
format_form_param,
|
||||
guess_content_type,
|
||||
peek_filelike_length,
|
||||
to_bytes,
|
||||
)
|
||||
|
||||
RequestData = typing.Union[
|
||||
dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]
|
||||
@ -195,7 +198,7 @@ class URLEncodedStream(ContentStream):
|
||||
|
||||
class MultipartStream(ContentStream):
|
||||
"""
|
||||
Request content as multipart encoded form data.
|
||||
Request content as streaming multipart encoded form data.
|
||||
"""
|
||||
|
||||
class DataField:
|
||||
@ -212,15 +215,35 @@ class MultipartStream(ContentStream):
|
||||
self.value = value
|
||||
|
||||
def render_headers(self) -> bytes:
|
||||
name = format_form_param("name", self.name)
|
||||
return b"".join([b"Content-Disposition: form-data; ", name, b"\r\n\r\n"])
|
||||
if not hasattr(self, "_headers"):
|
||||
name = format_form_param("name", self.name)
|
||||
self._headers = b"".join(
|
||||
[b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
|
||||
)
|
||||
|
||||
return self._headers
|
||||
|
||||
def render_data(self) -> bytes:
|
||||
return (
|
||||
self.value
|
||||
if isinstance(self.value, bytes)
|
||||
else self.value.encode("utf-8")
|
||||
)
|
||||
if not hasattr(self, "_data"):
|
||||
self._data = (
|
||||
self.value
|
||||
if isinstance(self.value, bytes)
|
||||
else self.value.encode("utf-8")
|
||||
)
|
||||
|
||||
return self._data
|
||||
|
||||
def get_length(self) -> int:
|
||||
headers = self.render_headers()
|
||||
data = self.render_data()
|
||||
return len(headers) + len(data)
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return True
|
||||
|
||||
def render(self) -> typing.Iterator[bytes]:
|
||||
yield self.render_headers()
|
||||
yield self.render_data()
|
||||
|
||||
class FileField:
|
||||
"""
|
||||
@ -235,67 +258,88 @@ class MultipartStream(ContentStream):
|
||||
self.name = name
|
||||
if not isinstance(value, tuple):
|
||||
self.filename = Path(str(getattr(value, "name", "upload"))).name
|
||||
self.file = (
|
||||
value
|
||||
) # type: typing.Union[typing.IO[str], typing.IO[bytes]]
|
||||
self.content_type = self.guess_content_type()
|
||||
self.file: typing.Union[typing.IO[str], typing.IO[bytes]] = value
|
||||
self.content_type = guess_content_type(self.filename)
|
||||
else:
|
||||
self.filename = value[0]
|
||||
self.file = value[1]
|
||||
self.content_type = (
|
||||
value[2] if len(value) > 2 else self.guess_content_type()
|
||||
value[2] if len(value) > 2 else guess_content_type(self.filename)
|
||||
)
|
||||
|
||||
def guess_content_type(self) -> typing.Optional[str]:
|
||||
if self.filename:
|
||||
return (
|
||||
mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
def get_length(self) -> int:
|
||||
headers = self.render_headers()
|
||||
|
||||
if isinstance(self.file, str):
|
||||
return len(headers) + len(self.file)
|
||||
|
||||
# Let's do our best not to read `file` into memory.
|
||||
try:
|
||||
file_length = peek_filelike_length(self.file)
|
||||
except OSError:
|
||||
# As a last resort, read file and cache contents for later.
|
||||
assert not hasattr(self, "_data")
|
||||
self._data = to_bytes(self.file.read())
|
||||
file_length = len(self._data)
|
||||
|
||||
return len(headers) + file_length
|
||||
|
||||
def render_headers(self) -> bytes:
|
||||
parts = [
|
||||
b"Content-Disposition: form-data; ",
|
||||
format_form_param("name", self.name),
|
||||
]
|
||||
if self.filename:
|
||||
filename = format_form_param("filename", self.filename)
|
||||
parts.extend([b"; ", filename])
|
||||
if self.content_type is not None:
|
||||
content_type = self.content_type.encode()
|
||||
parts.extend([b"\r\nContent-Type: ", content_type])
|
||||
parts.append(b"\r\n\r\n")
|
||||
return b"".join(parts)
|
||||
if not hasattr(self, "_headers"):
|
||||
parts = [
|
||||
b"Content-Disposition: form-data; ",
|
||||
format_form_param("name", self.name),
|
||||
]
|
||||
if self.filename:
|
||||
filename = format_form_param("filename", self.filename)
|
||||
parts.extend([b"; ", filename])
|
||||
if self.content_type is not None:
|
||||
content_type = self.content_type.encode()
|
||||
parts.extend([b"\r\nContent-Type: ", content_type])
|
||||
parts.append(b"\r\n\r\n")
|
||||
self._headers = b"".join(parts)
|
||||
|
||||
def render_data(self) -> bytes:
|
||||
content: typing.Union[str, bytes]
|
||||
return self._headers
|
||||
|
||||
def render_data(self) -> typing.Iterator[bytes]:
|
||||
if isinstance(self.file, str):
|
||||
content = self.file
|
||||
else:
|
||||
content = self.file.read()
|
||||
return content.encode("utf-8") if isinstance(content, str) else content
|
||||
yield to_bytes(self.file)
|
||||
return
|
||||
|
||||
def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
|
||||
body = BytesIO()
|
||||
if hasattr(self, "_data"):
|
||||
# Already rendered.
|
||||
yield self._data
|
||||
return
|
||||
|
||||
for chunk in self.file:
|
||||
yield to_bytes(chunk)
|
||||
|
||||
# Get ready for the next replay, if possible.
|
||||
if self.can_replay():
|
||||
assert self.file.seekable()
|
||||
self.file.seek(0)
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return True if isinstance(self.file, str) else self.file.seekable()
|
||||
|
||||
def render(self) -> typing.Iterator[bytes]:
|
||||
yield self.render_headers()
|
||||
yield from self.render_data()
|
||||
|
||||
def __init__(
|
||||
self, data: typing.Mapping, files: typing.Mapping, boundary: bytes = None
|
||||
) -> None:
|
||||
if boundary is None:
|
||||
boundary = binascii.hexlify(os.urandom(16))
|
||||
|
||||
for field in self.iter_fields(data, files):
|
||||
body.write(b"--%s\r\n" % boundary)
|
||||
body.write(field.render_headers())
|
||||
body.write(field.render_data())
|
||||
body.write(b"\r\n")
|
||||
|
||||
body.write(b"--%s--\r\n" % boundary)
|
||||
|
||||
self.boundary = boundary
|
||||
self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
|
||||
"ascii"
|
||||
)
|
||||
self.body = body.getvalue()
|
||||
self.fields = list(self._iter_fields(data, files))
|
||||
|
||||
def iter_fields(
|
||||
self, data: dict, files: dict
|
||||
def _iter_fields(
|
||||
self, data: typing.Mapping, files: typing.Mapping
|
||||
) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
|
||||
for name, value in data.items():
|
||||
if isinstance(value, list):
|
||||
@ -307,16 +351,42 @@ class MultipartStream(ContentStream):
|
||||
for name, value in files.items():
|
||||
yield self.FileField(name=name, value=value)
|
||||
|
||||
def iter_chunks(self) -> typing.Iterator[bytes]:
|
||||
for field in self.fields:
|
||||
yield b"--%s\r\n" % self.boundary
|
||||
yield from field.render()
|
||||
yield b"\r\n"
|
||||
yield b"--%s--\r\n" % self.boundary
|
||||
|
||||
def iter_chunks_lengths(self) -> typing.Iterator[int]:
|
||||
boundary_length = len(self.boundary)
|
||||
# Follow closely what `.iter_chunks()` does.
|
||||
for field in self.fields:
|
||||
yield 2 + boundary_length + 2
|
||||
yield field.get_length()
|
||||
yield 2
|
||||
yield 2 + boundary_length + 4
|
||||
|
||||
def get_content_length(self) -> int:
|
||||
return sum(self.iter_chunks_lengths())
|
||||
|
||||
# Content stream interface.
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return all(field.can_replay() for field in self.fields)
|
||||
|
||||
def get_headers(self) -> typing.Dict[str, str]:
|
||||
content_length = str(len(self.body))
|
||||
content_length = str(self.get_content_length())
|
||||
content_type = self.content_type
|
||||
return {"Content-Length": content_length, "Content-Type": content_type}
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
yield self.body
|
||||
for chunk in self.iter_chunks():
|
||||
yield chunk
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
yield self.body
|
||||
for chunk in self.iter_chunks():
|
||||
yield chunk
|
||||
|
||||
|
||||
def encode(
|
||||
|
||||
@ -2,6 +2,7 @@ import codecs
|
||||
import collections
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import netrc
|
||||
import os
|
||||
import re
|
||||
@ -310,6 +311,38 @@ def unquote(value: str) -> str:
|
||||
return value[1:-1] if value[0] == value[-1] == '"' else value
|
||||
|
||||
|
||||
def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
|
||||
if filename:
|
||||
return mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
return None
|
||||
|
||||
|
||||
def peek_filelike_length(stream: typing.IO) -> int:
|
||||
"""
|
||||
Given a file-like stream object, return its length in number of bytes
|
||||
without reading it into memory.
|
||||
"""
|
||||
try:
|
||||
# Is it an actual file?
|
||||
fd = stream.fileno()
|
||||
except OSError:
|
||||
# No... Maybe it's something that supports random access, like `io.BytesIO`?
|
||||
try:
|
||||
# Assuming so, go to end of stream to figure out its length,
|
||||
# then put it back in place.
|
||||
offset = stream.tell()
|
||||
length = stream.seek(0, os.SEEK_END)
|
||||
stream.seek(offset)
|
||||
except OSError:
|
||||
# Not even that? Sorry, we're doomed...
|
||||
raise
|
||||
else:
|
||||
return length
|
||||
else:
|
||||
# Yup, seems to be an actual file.
|
||||
return os.fstat(fd).st_size
|
||||
|
||||
|
||||
def flatten_queryparams(
|
||||
queryparams: typing.Mapping[
|
||||
str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
|
||||
|
||||
@ -101,20 +101,27 @@ async def test_multipart_file_tuple():
|
||||
assert multipart["file"] == [b"<file content>"]
|
||||
|
||||
|
||||
def test_multipart_encode():
|
||||
def test_multipart_encode(tmp_path: typing.Any) -> None:
|
||||
path = str(tmp_path / "name.txt")
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"<file content>")
|
||||
|
||||
data = {
|
||||
"a": "1",
|
||||
"b": b"C",
|
||||
"c": ["11", "22", "33"],
|
||||
"d": "",
|
||||
}
|
||||
files = {"file": ("name.txt", io.BytesIO(b"<file content>"))}
|
||||
files = {"file": ("name.txt", open(path, "rb"))}
|
||||
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
|
||||
|
||||
stream = encode(data=data, files=files)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
assert stream.body == (
|
||||
content = (
|
||||
'--{0}\r\nContent-Disposition: form-data; name="a"\r\n\r\n1\r\n'
|
||||
'--{0}\r\nContent-Disposition: form-data; name="b"\r\n\r\nC\r\n'
|
||||
'--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n11\r\n'
|
||||
@ -127,17 +134,20 @@ def test_multipart_encode():
|
||||
"--{0}--\r\n"
|
||||
"".format(boundary).encode("ascii")
|
||||
)
|
||||
assert stream.get_headers()["Content-Length"] == str(len(content))
|
||||
assert b"".join(stream) == content
|
||||
|
||||
|
||||
def test_multipart_encode_files_allows_filenames_as_none():
|
||||
def test_multipart_encode_files_allows_filenames_as_none() -> None:
|
||||
files = {"file": (None, io.BytesIO(b"<file content>"))}
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
|
||||
|
||||
stream = encode(data={}, files=files)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
assert stream.body == (
|
||||
assert b"".join(stream) == (
|
||||
'--{0}\r\nContent-Disposition: form-data; name="file"\r\n\r\n'
|
||||
"<file content>\r\n--{0}--\r\n"
|
||||
"".format(boundary).encode("ascii")
|
||||
@ -153,16 +163,17 @@ def test_multipart_encode_files_allows_filenames_as_none():
|
||||
],
|
||||
)
|
||||
def test_multipart_encode_files_guesses_correct_content_type(
|
||||
file_name, expected_content_type
|
||||
):
|
||||
file_name: str, expected_content_type: str
|
||||
) -> None:
|
||||
files = {"file": (file_name, io.BytesIO(b"<file content>"))}
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
|
||||
|
||||
stream = encode(data={}, files=files)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
assert stream.body == (
|
||||
assert b"".join(stream) == (
|
||||
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
|
||||
f'filename="{file_name}"\r\nContent-Type: '
|
||||
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
|
||||
@ -170,21 +181,64 @@ def test_multipart_encode_files_guesses_correct_content_type(
|
||||
)
|
||||
|
||||
|
||||
def test_multipart_encode_files_allows_str_content():
|
||||
def test_multipart_encode_files_allows_str_content() -> None:
|
||||
files = {"file": ("test.txt", "<string content>", "text/plain")}
|
||||
with mock.patch("os.urandom", return_value=os.urandom(16)):
|
||||
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
|
||||
|
||||
stream = encode(data={}, files=files)
|
||||
assert stream.can_replay()
|
||||
|
||||
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
|
||||
assert stream.body == (
|
||||
content = (
|
||||
'--{0}\r\nContent-Disposition: form-data; name="file"; '
|
||||
'filename="test.txt"\r\n'
|
||||
"Content-Type: text/plain\r\n\r\n<string content>\r\n"
|
||||
"--{0}--\r\n"
|
||||
"".format(boundary).encode("ascii")
|
||||
)
|
||||
assert stream.get_headers()["Content-Length"] == str(len(content))
|
||||
assert b"".join(stream) == content
|
||||
|
||||
|
||||
def test_multipart_encode_non_seekable_filelike() -> None:
|
||||
"""
|
||||
Test that special readable but non-seekable filelike objects are supported,
|
||||
at the cost of reading them into memory at most once.
|
||||
"""
|
||||
|
||||
class IteratorIO(io.IOBase):
|
||||
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
|
||||
self._iterator = iterator
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return False
|
||||
|
||||
def read(self, *args: typing.Any) -> bytes:
|
||||
return b"".join(self._iterator)
|
||||
|
||||
def data() -> typing.Iterator[bytes]:
|
||||
yield b"Hello"
|
||||
yield b"World"
|
||||
|
||||
fileobj = IteratorIO(data())
|
||||
files = {"file": fileobj}
|
||||
stream = encode(files=files, boundary=b"+++")
|
||||
assert not stream.can_replay()
|
||||
|
||||
content = (
|
||||
b"--+++\r\n"
|
||||
b'Content-Disposition: form-data; name="file"; filename="upload"\r\n'
|
||||
b"Content-Type: application/octet-stream\r\n"
|
||||
b"\r\n"
|
||||
b"HelloWorld\r\n"
|
||||
b"--+++--\r\n"
|
||||
)
|
||||
assert stream.get_headers() == {
|
||||
"Content-Type": "multipart/form-data; boundary=+++",
|
||||
"Content-Length": str(len(content)),
|
||||
}
|
||||
assert b"".join(stream) == content
|
||||
|
||||
|
||||
class TestHeaderParamHTML5Formatting:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user