Typing: use wsgiref.types to validate types and fix issues uncovered (#2467)

* Typing: use wsgiref.types to validate types and fix issues uncovered

- start_response() must return a write(bytes) function, even though this
  is now deprecated. It's fine to be a no-op here.
- sys.exc_info() can return (None, None, None), so make sure to handle that case.

* remove typing_extensions

Co-authored-by: Martijn Pieters <mj@zopatista.com>
This commit is contained in:
Adrian Garcia Badaracco 2022-11-29 11:55:21 -06:00 committed by GitHub
parent 8327e13454
commit 1ff67ea47c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,14 +1,31 @@
import io
import itertools
import sys
import types
import typing
from .._models import Request, Response
from .._types import SyncByteStream
from .base import BaseTransport
_T = typing.TypeVar("_T")
_ExcInfo = typing.Tuple[typing.Type[BaseException], BaseException, types.TracebackType]
_OptExcInfo = typing.Union[_ExcInfo, typing.Tuple[None, None, None]]
def _skip_leading_empty_chunks(body: typing.Iterable[bytes]) -> typing.Iterable[bytes]:
# backported wsgiref.types definitions from Python 3.11
StartResponse = typing.Callable[
[str, typing.List[typing.Tuple[str, str]], typing.Optional[_OptExcInfo]],
typing.Callable[[bytes], object],
]
WSGIApplication = typing.Callable[
[typing.Dict[str, typing.Any], StartResponse], typing.Iterable[bytes]
]
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
body = iter(body)
for chunk in body:
if chunk:
@ -54,7 +71,7 @@ class WSGITransport(BaseTransport):
Arguments:
* `app` - The ASGI application.
* `app` - The WSGI application.
* `raise_app_exceptions` - Boolean indicating if exceptions in the application
should be raised. Default to `True`. Can be set to `False` for use cases
such as testing the content of a client 500 response.
@ -65,7 +82,7 @@ class WSGITransport(BaseTransport):
def __init__(
self,
app: typing.Callable[..., typing.Any],
app: WSGIApplication,
raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
@ -111,12 +128,13 @@ class WSGITransport(BaseTransport):
def start_response(
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
exc_info: typing.Any = None,
) -> None:
exc_info: typing.Optional[_OptExcInfo] = None,
) -> typing.Callable[[bytes], typing.Any]:
nonlocal seen_status, seen_response_headers, seen_exc_info
seen_status = status
seen_response_headers = response_headers
seen_exc_info = exc_info
return lambda _: None
result = self.app(environ, start_response)
@ -124,7 +142,7 @@ class WSGITransport(BaseTransport):
assert seen_status is not None
assert seen_response_headers is not None
if seen_exc_info and self.raise_app_exceptions:
if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
raise seen_exc_info[1]
status_code = int(seen_status.split()[0])