PYTHON-3735 Add types to PyMongo auth module (#1231)

This commit is contained in:
Noah Stapp 2023-06-14 11:27:58 -07:00 committed by GitHub
parent ece45b1edf
commit 1269c006da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 32 deletions

View File

@ -13,15 +13,17 @@
# limitations under the License.
"""Authentication helpers."""
from __future__ import annotations
import functools
import hashlib
import hmac
import os
import socket
import typing
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import Callable, Mapping
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional
from urllib.parse import quote
from bson.binary import Binary
@ -31,6 +33,10 @@ from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCPrope
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
if TYPE_CHECKING:
from pymongo.hello import Hello
from pymongo.pool import SocketInfo
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
@ -66,21 +72,21 @@ class _Cache:
_hash_val = hash("_Cache")
def __init__(self):
def __init__(self) -> None:
self.data = None
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self):
def __hash__(self) -> int:
return self._hash_val
@ -101,7 +107,14 @@ _AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(mech, source, user, passwd, extra, database):
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
@ -175,17 +188,21 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir, sec):
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together (python 3.x)."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response):
def _parse_scram_response(response: bytes) -> dict:
"""Split a scram response into key, value pairs."""
return dict(item.split(b"=", 1) for item in response.split(b","))
return dict(
typing.cast(typing.Tuple[str, str], item.split(b"=", 1)) for item in response.split(b",")
)
def _authenticate_scram_start(credentials, mechanism):
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
@ -203,7 +220,9 @@ def _authenticate_scram_start(credentials, mechanism):
return nonce, first_bare, cmd
def _authenticate_scram(credentials, sock_info, mechanism):
def _authenticate_scram(
credentials: MongoCredential, sock_info: SocketInfo, mechanism: str
) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == "SCRAM-SHA-256":
@ -287,7 +306,7 @@ def _authenticate_scram(credentials, sock_info, mechanism):
raise OperationFailure("SASL conversation failed to complete.")
def _password_digest(username, password):
def _password_digest(username: str, password: str) -> str:
"""Get a password digest to use for authentication."""
if not isinstance(password, str):
raise TypeError("password must be an instance of str")
@ -302,7 +321,7 @@ def _password_digest(username, password):
return md5hash.hexdigest()
def _auth_key(nonce, username, password):
def _auth_key(nonce: str, username: str, password: str) -> str:
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5()
@ -311,7 +330,7 @@ def _auth_key(nonce, username, password):
return md5hash.hexdigest()
def _canonicalize_hostname(hostname):
def _canonicalize_hostname(hostname: str) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
@ -326,7 +345,7 @@ def _canonicalize_hostname(hostname):
return name[0].lower()
def _authenticate_gssapi(credentials, sock_info):
def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
@ -443,7 +462,7 @@ def _authenticate_gssapi(credentials, sock_info):
raise OperationFailure(str(exc))
def _authenticate_plain(credentials, sock_info):
def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
@ -460,7 +479,7 @@ def _authenticate_plain(credentials, sock_info):
sock_info.command(source, cmd)
def _authenticate_x509(credentials, sock_info):
def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None:
"""Authenticate using MONGODB-X509."""
ctx = sock_info.auth_ctx
if ctx and ctx.speculate_succeeded():
@ -471,7 +490,7 @@ def _authenticate_x509(credentials, sock_info):
sock_info.command("$external", cmd)
def _authenticate_mongo_cr(credentials, sock_info):
def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
@ -486,7 +505,7 @@ def _authenticate_mongo_cr(credentials, sock_info):
sock_info.command(source, query)
def _authenticate_default(credentials, sock_info):
def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None:
if sock_info.max_wire_version >= 7:
if sock_info.negotiated_mechs:
mechs = sock_info.negotiated_mechs
@ -518,35 +537,39 @@ _AUTH_MAP: Mapping[str, Callable] = {
class _AuthContext:
def __init__(self, credentials, address):
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
self.credentials = credentials
self.speculative_authenticate = None
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
self.address = address
@staticmethod
def from_credentials(creds, address):
def from_credentials(
creds: MongoCredential, address: tuple[str, int]
) -> Optional[_AuthContext]:
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
if spec_cls:
return spec_cls(creds, address)
return None
def speculate_command(self):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
raise NotImplementedError
def parse_response(self, hello):
def parse_response(self, hello: Hello) -> None:
self.speculative_authenticate = hello.speculative_authenticate
def speculate_succeeded(self):
def speculate_succeeded(self) -> bool:
return bool(self.speculative_authenticate)
class _ScramContext(_AuthContext):
def __init__(self, credentials, address, mechanism):
def __init__(
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
) -> None:
super().__init__(credentials, address)
self.scram_data = None
self.scram_data: Optional[tuple[bytes, bytes]] = None
self.mechanism = mechanism
def speculate_command(self):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
# The 'db' field is included only on the speculative command.
cmd["db"] = self.credentials.source
@ -556,7 +579,7 @@ class _ScramContext(_AuthContext):
class _X509Context(_AuthContext):
def speculate_command(self):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")])
if self.credentials.username is not None:
cmd["user"] = self.credentials.username
@ -564,7 +587,7 @@ class _X509Context(_AuthContext):
class _OIDCContext(_AuthContext):
def speculate_command(self):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.auth_start_cmd(False)
if cmd is None:
@ -582,7 +605,9 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
}
def authenticate(credentials, sock_info, reauthenticate=False):
def authenticate(
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False
) -> None:
"""Authenticate sock_info."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]

View File

@ -758,7 +758,9 @@ class SocketInfo:
cmd["saslSupportedMechs"] = creds.source + "." + creds.username
auth_ctx = auth._AuthContext.from_credentials(creds, self.address)
if auth_ctx:
cmd["speculativeAuthenticate"] = auth_ctx.speculate_command()
speculative_authenticate = auth_ctx.speculate_command()
if speculative_authenticate is not None:
cmd["speculativeAuthenticate"] = speculative_authenticate
else:
auth_ctx = None