PYTHON-3735 Add types to PyMongo auth module (#1231)
This commit is contained in:
parent
ece45b1edf
commit
1269c006da
@ -13,15 +13,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
@ -31,6 +33,10 @@ from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCPrope
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.pool import SocketInfo
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
try:
|
||||
@ -66,21 +72,21 @@ class _Cache:
|
||||
|
||||
_hash_val = hash("_Cache")
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.data = None
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Two instances must always compare equal.
|
||||
if isinstance(other, _Cache):
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, _Cache):
|
||||
return False
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return self._hash_val
|
||||
|
||||
|
||||
@ -101,7 +107,14 @@ _AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(mech, source, user, passwd, extra, database):
|
||||
def _build_credentials_tuple(
|
||||
mech: str,
|
||||
source: Optional[str],
|
||||
user: str,
|
||||
passwd: str,
|
||||
extra: Mapping[str, Any],
|
||||
database: Optional[str],
|
||||
) -> MongoCredential:
|
||||
"""Build and return a mechanism specific credentials tuple."""
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
|
||||
raise ConfigurationError(f"{mech} requires a username.")
|
||||
@ -175,17 +188,21 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
|
||||
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir, sec):
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
"""XOR two byte strings together (python 3.x)."""
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response):
|
||||
def _parse_scram_response(response: bytes) -> dict:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(item.split(b"=", 1) for item in response.split(b","))
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[str, str], item.split(b"=", 1)) for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_scram_start(credentials, mechanism):
|
||||
def _authenticate_scram_start(
|
||||
credentials: MongoCredential, mechanism: str
|
||||
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
|
||||
username = credentials.username
|
||||
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
|
||||
nonce = standard_b64encode(os.urandom(32))
|
||||
@ -203,7 +220,9 @@ def _authenticate_scram_start(credentials, mechanism):
|
||||
return nonce, first_bare, cmd
|
||||
|
||||
|
||||
def _authenticate_scram(credentials, sock_info, mechanism):
|
||||
def _authenticate_scram(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, mechanism: str
|
||||
) -> None:
|
||||
"""Authenticate using SCRAM."""
|
||||
username = credentials.username
|
||||
if mechanism == "SCRAM-SHA-256":
|
||||
@ -287,7 +306,7 @@ def _authenticate_scram(credentials, sock_info, mechanism):
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
|
||||
def _password_digest(username, password):
|
||||
def _password_digest(username: str, password: str) -> str:
|
||||
"""Get a password digest to use for authentication."""
|
||||
if not isinstance(password, str):
|
||||
raise TypeError("password must be an instance of str")
|
||||
@ -302,7 +321,7 @@ def _password_digest(username, password):
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _auth_key(nonce, username, password):
|
||||
def _auth_key(nonce: str, username: str, password: str) -> str:
|
||||
"""Get an auth key to use for authentication."""
|
||||
digest = _password_digest(username, password)
|
||||
md5hash = hashlib.md5()
|
||||
@ -311,7 +330,7 @@ def _auth_key(nonce, username, password):
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _canonicalize_hostname(hostname):
|
||||
def _canonicalize_hostname(hostname: str) -> str:
|
||||
"""Canonicalize hostname following MIT-krb5 behavior."""
|
||||
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
|
||||
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
|
||||
@ -326,7 +345,7 @@ def _canonicalize_hostname(hostname):
|
||||
return name[0].lower()
|
||||
|
||||
|
||||
def _authenticate_gssapi(credentials, sock_info):
|
||||
def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
"""Authenticate using GSSAPI."""
|
||||
if not HAVE_KERBEROS:
|
||||
raise ConfigurationError(
|
||||
@ -443,7 +462,7 @@ def _authenticate_gssapi(credentials, sock_info):
|
||||
raise OperationFailure(str(exc))
|
||||
|
||||
|
||||
def _authenticate_plain(credentials, sock_info):
|
||||
def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
"""Authenticate using SASL PLAIN (RFC 4616)"""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
@ -460,7 +479,7 @@ def _authenticate_plain(credentials, sock_info):
|
||||
sock_info.command(source, cmd)
|
||||
|
||||
|
||||
def _authenticate_x509(credentials, sock_info):
|
||||
def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
"""Authenticate using MONGODB-X509."""
|
||||
ctx = sock_info.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
@ -471,7 +490,7 @@ def _authenticate_x509(credentials, sock_info):
|
||||
sock_info.command("$external", cmd)
|
||||
|
||||
|
||||
def _authenticate_mongo_cr(credentials, sock_info):
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
"""Authenticate using MONGODB-CR."""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
@ -486,7 +505,7 @@ def _authenticate_mongo_cr(credentials, sock_info):
|
||||
sock_info.command(source, query)
|
||||
|
||||
|
||||
def _authenticate_default(credentials, sock_info):
|
||||
def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
if sock_info.max_wire_version >= 7:
|
||||
if sock_info.negotiated_mechs:
|
||||
mechs = sock_info.negotiated_mechs
|
||||
@ -518,35 +537,39 @@ _AUTH_MAP: Mapping[str, Callable] = {
|
||||
|
||||
|
||||
class _AuthContext:
|
||||
def __init__(self, credentials, address):
|
||||
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
|
||||
self.credentials = credentials
|
||||
self.speculative_authenticate = None
|
||||
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
|
||||
self.address = address
|
||||
|
||||
@staticmethod
|
||||
def from_credentials(creds, address):
|
||||
def from_credentials(
|
||||
creds: MongoCredential, address: tuple[str, int]
|
||||
) -> Optional[_AuthContext]:
|
||||
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
|
||||
if spec_cls:
|
||||
return spec_cls(creds, address)
|
||||
return None
|
||||
|
||||
def speculate_command(self):
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_response(self, hello):
|
||||
def parse_response(self, hello: Hello) -> None:
|
||||
self.speculative_authenticate = hello.speculative_authenticate
|
||||
|
||||
def speculate_succeeded(self):
|
||||
def speculate_succeeded(self) -> bool:
|
||||
return bool(self.speculative_authenticate)
|
||||
|
||||
|
||||
class _ScramContext(_AuthContext):
|
||||
def __init__(self, credentials, address, mechanism):
|
||||
def __init__(
|
||||
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
|
||||
) -> None:
|
||||
super().__init__(credentials, address)
|
||||
self.scram_data = None
|
||||
self.scram_data: Optional[tuple[bytes, bytes]] = None
|
||||
self.mechanism = mechanism
|
||||
|
||||
def speculate_command(self):
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
|
||||
# The 'db' field is included only on the speculative command.
|
||||
cmd["db"] = self.credentials.source
|
||||
@ -556,7 +579,7 @@ class _ScramContext(_AuthContext):
|
||||
|
||||
|
||||
class _X509Context(_AuthContext):
|
||||
def speculate_command(self):
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")])
|
||||
if self.credentials.username is not None:
|
||||
cmd["user"] = self.credentials.username
|
||||
@ -564,7 +587,7 @@ class _X509Context(_AuthContext):
|
||||
|
||||
|
||||
class _OIDCContext(_AuthContext):
|
||||
def speculate_command(self):
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
authenticator = _get_authenticator(self.credentials, self.address)
|
||||
cmd = authenticator.auth_start_cmd(False)
|
||||
if cmd is None:
|
||||
@ -582,7 +605,9 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
|
||||
}
|
||||
|
||||
|
||||
def authenticate(credentials, sock_info, reauthenticate=False):
|
||||
def authenticate(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False
|
||||
) -> None:
|
||||
"""Authenticate sock_info."""
|
||||
mechanism = credentials.mechanism
|
||||
auth_func = _AUTH_MAP[mechanism]
|
||||
|
||||
@ -758,7 +758,9 @@ class SocketInfo:
|
||||
cmd["saslSupportedMechs"] = creds.source + "." + creds.username
|
||||
auth_ctx = auth._AuthContext.from_credentials(creds, self.address)
|
||||
if auth_ctx:
|
||||
cmd["speculativeAuthenticate"] = auth_ctx.speculate_command()
|
||||
speculative_authenticate = auth_ctx.speculate_command()
|
||||
if speculative_authenticate is not None:
|
||||
cmd["speculativeAuthenticate"] = speculative_authenticate
|
||||
else:
|
||||
auth_ctx = None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user