Tighten multipart implementation types (#975)

This commit is contained in:
Florimond Manca 2020-05-21 17:41:36 +02:00 committed by GitHub
parent f1b3d74abb
commit 9f58afd8f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 18 deletions

View File

@ -8,7 +8,7 @@ from urllib.parse import urlencode
import httpcore
from ._exceptions import StreamConsumed
from ._types import RequestData, RequestFiles
from ._types import FileContent, FileTypes, RequestData, RequestFiles
from ._utils import (
format_form_param,
guess_content_type,
@ -227,22 +227,25 @@ class MultipartStream(ContentStream):
A single file field item, within a multipart form field.
"""
def __init__(
self,
name: str,
value: typing.Union[typing.IO[str], typing.IO[bytes], tuple],
) -> None:
def __init__(self, name: str, value: FileTypes) -> None:
self.name = name
if not isinstance(value, tuple):
self.filename = Path(str(getattr(value, "name", "upload"))).name
self.file: typing.Union[typing.IO[str], typing.IO[bytes]] = value
self.content_type = guess_content_type(self.filename)
fileobj: FileContent
if isinstance(value, tuple):
try:
filename, fileobj, content_type = value # type: ignore
except ValueError:
filename, fileobj = value # type: ignore
content_type = guess_content_type(filename)
else:
self.filename = value[0]
self.file = value[1]
self.content_type = (
value[2] if len(value) > 2 else guess_content_type(self.filename)
)
filename = Path(str(getattr(value, "name", "upload"))).name
fileobj = value
content_type = guess_content_type(filename)
self.filename = filename
self.file = fileobj
self.content_type = content_type
def get_length(self) -> int:
headers = self.render_headers()
@ -304,7 +307,7 @@ class MultipartStream(ContentStream):
yield from self.render_data()
def __init__(
self, data: typing.Mapping, files: typing.Mapping, boundary: bytes = None
self, data: typing.Mapping, files: RequestFiles, boundary: bytes = None
) -> None:
if boundary is None:
boundary = binascii.hexlify(os.urandom(16))
@ -316,7 +319,7 @@ class MultipartStream(ContentStream):
self.fields = list(self._iter_fields(data, files))
def _iter_fields(
self, data: typing.Mapping, files: typing.Mapping
self, data: typing.Mapping, files: RequestFiles
) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
for name, value in data.items():
if isinstance(value, list):

View File

@ -11,7 +11,7 @@ combine_as_imports = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = httpx,tests
known_third_party = brotli,certifi,chardet,cryptography,hstspreload,httpcore,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn
known_third_party = brotli,certifi,chardet,cryptography,hstspreload,httpcore,pytest,rfc3986,setuptools,sniffio,trio,trustme,uvicorn
line_length = 88
multi_line_output = 3