Typing: always fill in generic type parameters (#2468)

* Typing: always fill in generic type parameters

Being explicit about the parameters helps find bugs and makes the library
easier to use for users.

- Tell mypy to disallow generics without parameter values
- Give all generic types parameters values

* fix things that aren't coming in from other commits

* lint

Co-authored-by: Martijn Pieters <mj@zopatista.com>
Co-authored-by: Tom Christie <tom@tomchristie.com>
This commit is contained in:
Adrian Garcia Badaracco 2022-11-29 10:36:03 -06:00 committed by GitHub
parent 16e2830624
commit 1b4e7fbb48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 39 additions and 18 deletions

View File

@ -10,6 +10,9 @@ from ._exceptions import ProtocolError
from ._models import Request, Response
from ._utils import to_bytes, to_str, unquote
if typing.TYPE_CHECKING: # pragma: no cover
from hashlib import _Hash
class Auth:
"""
@ -139,7 +142,7 @@ class BasicAuth(Auth):
class DigestAuth(Auth):
_ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
_ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
"SHA": hashlib.sha1,

View File

@ -179,7 +179,12 @@ def print_response(response: Response) -> None:
console.print(f"<{len(response.content)} bytes of binary data>")
def format_certificate(cert: dict) -> str: # pragma: no cover
_PCTRTT = typing.Tuple[typing.Tuple[str, str], ...]
_PCTRTTT = typing.Tuple[_PCTRTT, ...]
_PeerCertRetDictType = typing.Dict[str, typing.Union[str, _PCTRTTT, _PCTRTT]]
def format_certificate(cert: _PeerCertRetDictType) -> str: # pragma: no cover
lines = []
for key, value in cert.items():
if isinstance(value, (list, tuple)):

View File

@ -3,7 +3,7 @@ import email.message
import json as jsonlib
import typing
import urllib.request
from collections.abc import Mapping, MutableMapping
from collections.abc import Mapping
from http.cookiejar import Cookie, CookieJar
from ._content import ByteStream, UnattachedStream, encode_request, encode_response
@ -1002,7 +1002,7 @@ class Response:
await self.stream.aclose()
class Cookies(MutableMapping):
class Cookies(typing.MutableMapping[str, str]):
"""
HTTP Cookies, as a mutable mapping.
"""

View File

@ -14,6 +14,16 @@ if typing.TYPE_CHECKING: # pragma: no cover
Event = typing.Union[asyncio.Event, trio.Event]
_Message = typing.Dict[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[
[typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None]
]
_ASGIApp = typing.Callable[
[typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None]
]
def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
import trio
@ -68,7 +78,7 @@ class ASGITransport(AsyncBaseTransport):
def __init__(
self,
app: typing.Callable,
app: _ASGIApp,
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
@ -113,7 +123,7 @@ class ASGITransport(AsyncBaseTransport):
# ASGI callables.
async def receive() -> dict:
async def receive() -> typing.Dict[str, typing.Any]:
nonlocal request_complete
if request_complete:
@ -127,7 +137,7 @@ class ASGITransport(AsyncBaseTransport):
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict) -> None:
async def send(message: typing.Dict[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started
if message["type"] == "http.response.start":

View File

@ -6,7 +6,7 @@ from .base import AsyncBaseTransport, BaseTransport
class MockTransport(AsyncBaseTransport, BaseTransport):
def __init__(self, handler: typing.Callable) -> None:
def __init__(self, handler: typing.Callable[[Request], Response]) -> None:
self.handler = handler
def handle_request(
@ -29,6 +29,6 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
# https://simonwillison.net/2020/Sep/2/await-me-maybe/
if asyncio.iscoroutine(response):
response = await response
response = await response # type: ignore[func-returns-value,assignment]
return response

View File

@ -8,7 +8,7 @@ from .._types import SyncByteStream
from .base import BaseTransport
def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
def _skip_leading_empty_chunks(body: typing.Iterable[bytes]) -> typing.Iterable[bytes]:
body = iter(body)
for chunk in body:
if chunk:
@ -65,7 +65,7 @@ class WSGITransport(BaseTransport):
def __init__(
self,
app: typing.Callable,
app: typing.Callable[..., typing.Any],
raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
@ -109,7 +109,9 @@ class WSGITransport(BaseTransport):
seen_exc_info = None
def start_response(
status: str, response_headers: list, exc_info: typing.Any = None
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
exc_info: typing.Any = None,
) -> None:
nonlocal seen_status, seen_response_headers, seen_exc_info
seen_status = status

View File

@ -570,7 +570,7 @@ class QueryParams(typing.Mapping[str, str]):
for k, v in dict_value.items()
}
def keys(self) -> typing.KeysView:
def keys(self) -> typing.KeysView[str]:
"""
Return all the keys in the query params.
@ -581,7 +581,7 @@ class QueryParams(typing.Mapping[str, str]):
"""
return self._dict.keys()
def values(self) -> typing.ValuesView:
def values(self) -> typing.ValuesView[str]:
"""
Return all the values in the query params. If a key occurs more than once
only the first item for that key is returned.
@ -593,7 +593,7 @@ class QueryParams(typing.Mapping[str, str]):
"""
return {k: v[0] for k, v in self._dict.items()}.values()
def items(self) -> typing.ItemsView:
def items(self) -> typing.ItemsView[str, str]:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.

View File

@ -508,7 +508,7 @@ class URLPattern:
return True
@property
def priority(self) -> tuple:
def priority(self) -> typing.Tuple[int, int, int]:
"""
The priority allows URLPattern instances to be sortable, so that
we can match from most specific to least specific.

View File

@ -4,6 +4,7 @@ max-line-length = 120
[mypy]
disallow_untyped_defs = True
disallow_any_generics = True
ignore_missing_imports = True
no_implicit_optional = True
show_error_codes = True

View File

@ -428,7 +428,7 @@ async def test_digest_auth(
assert response.status_code == 200
assert len(response.history) == 1
authorization = typing.cast(dict, response.json())["auth"]
authorization = typing.cast(typing.Dict[str, typing.Any], response.json())["auth"]
scheme, _, fields = authorization.partition(" ")
assert scheme == "Digest"
@ -459,7 +459,7 @@ async def test_digest_auth_no_specified_qop() -> None:
assert response.status_code == 200
assert len(response.history) == 1
authorization = typing.cast(dict, response.json())["auth"]
authorization = typing.cast(typing.Dict[str, typing.Any], response.json())["auth"]
scheme, _, fields = authorization.partition(" ")
assert scheme == "Digest"