From 1269c006da2ad9c35d812ff309ded7beebc50e81 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 14 Jun 2023 11:27:58 -0700 Subject: [PATCH] PYTHON-3735 Add types to PyMongo auth module (#1231) --- pymongo/auth.py | 87 +++++++++++++++++++++++++++++++------------------ pymongo/pool.py | 4 ++- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 00b6faa6f..b4d04f8d1 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -13,15 +13,17 @@ # limitations under the License. """Authentication helpers.""" +from __future__ import annotations import functools import hashlib import hmac import os import socket +import typing from base64 import standard_b64decode, standard_b64encode from collections import namedtuple -from typing import Callable, Mapping +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional from urllib.parse import quote from bson.binary import Binary @@ -31,6 +33,10 @@ from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCPrope from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep +if TYPE_CHECKING: + from pymongo.hello import Hello + from pymongo.pool import SocketInfo + HAVE_KERBEROS = True _USE_PRINCIPAL = False try: @@ -66,21 +72,21 @@ class _Cache: _hash_val = hash("_Cache") - def __init__(self): + def __init__(self) -> None: self.data = None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: # Two instances must always compare equal. if isinstance(other, _Cache): return True return NotImplemented - def __ne__(self, other): + def __ne__(self, other: object) -> bool: if isinstance(other, _Cache): return False return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return self._hash_val @@ -101,7 +107,14 @@ _AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) """Mechanism properties for MONGODB-AWS authentication.""" -def _build_credentials_tuple(mech, source, user, passwd, extra, database): +def _build_credentials_tuple( + mech: str, + source: Optional[str], + user: str, + passwd: str, + extra: Mapping[str, Any], + database: Optional[str], +) -> MongoCredential: """Build and return a mechanism specific credentials tuple.""" if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: raise ConfigurationError(f"{mech} requires a username.") @@ -175,17 +188,21 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): return MongoCredential(mech, source_database, user, passwd, None, _Cache()) -def _xor(fir, sec): +def _xor(fir: bytes, sec: bytes) -> bytes: """XOR two byte strings together (python 3.x).""" return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) -def _parse_scram_response(response): +def _parse_scram_response(response: bytes) -> dict: """Split a scram response into key, value pairs.""" - return dict(item.split(b"=", 1) for item in response.split(b",")) + return dict( + typing.cast(typing.Tuple[str, str], item.split(b"=", 1)) for item in response.split(b",") + ) -def _authenticate_scram_start(credentials, mechanism): +def _authenticate_scram_start( + credentials: MongoCredential, mechanism: str +) -> tuple[bytes, bytes, MutableMapping[str, Any]]: username = credentials.username user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") nonce = standard_b64encode(os.urandom(32)) @@ -203,7 +220,9 @@ def _authenticate_scram_start(credentials, mechanism): return nonce, first_bare, cmd -def _authenticate_scram(credentials, sock_info, mechanism): +def _authenticate_scram( + credentials: MongoCredential, sock_info: SocketInfo, mechanism: str +) -> None: """Authenticate using SCRAM.""" username = credentials.username if mechanism == "SCRAM-SHA-256": @@ -287,7 +306,7 @@ def _authenticate_scram(credentials, sock_info, mechanism): raise OperationFailure("SASL conversation failed to complete.") -def _password_digest(username, password): +def _password_digest(username: str, password: str) -> str: """Get a password digest to use for authentication.""" if not isinstance(password, str): raise TypeError("password must be an instance of str") @@ -302,7 +321,7 @@ def _password_digest(username, password): return md5hash.hexdigest() -def _auth_key(nonce, username, password): +def _auth_key(nonce: str, username: str, password: str) -> str: """Get an auth key to use for authentication.""" digest = _password_digest(username, password) md5hash = hashlib.md5() @@ -311,7 +330,7 @@ def _auth_key(nonce, username, password): return md5hash.hexdigest() -def _canonicalize_hostname(hostname): +def _canonicalize_hostname(hostname: str) -> str: """Canonicalize hostname following MIT-krb5 behavior.""" # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520 af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( @@ -326,7 +345,7 @@ def _canonicalize_hostname(hostname): return name[0].lower() -def _authenticate_gssapi(credentials, sock_info): +def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None: """Authenticate using GSSAPI.""" if not HAVE_KERBEROS: raise ConfigurationError( @@ -443,7 +462,7 @@ def _authenticate_gssapi(credentials, sock_info): raise OperationFailure(str(exc)) -def _authenticate_plain(credentials, sock_info): +def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None: """Authenticate using SASL PLAIN (RFC 4616)""" source = credentials.source username = credentials.username @@ -460,7 +479,7 @@ def _authenticate_plain(credentials, sock_info): sock_info.command(source, cmd) -def _authenticate_x509(credentials, sock_info): +def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None: """Authenticate using MONGODB-X509.""" ctx = sock_info.auth_ctx if ctx and ctx.speculate_succeeded(): @@ -471,7 +490,7 @@ def _authenticate_x509(credentials, sock_info): sock_info.command("$external", cmd) -def _authenticate_mongo_cr(credentials, sock_info): +def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None: """Authenticate using MONGODB-CR.""" source = credentials.source username = credentials.username @@ -486,7 +505,7 @@ def _authenticate_mongo_cr(credentials, sock_info): sock_info.command(source, query) -def _authenticate_default(credentials, sock_info): +def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None: if sock_info.max_wire_version >= 7: if sock_info.negotiated_mechs: mechs = sock_info.negotiated_mechs @@ -518,35 +537,39 @@ _AUTH_MAP: Mapping[str, Callable] = { class _AuthContext: - def __init__(self, credentials, address): + def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None: self.credentials = credentials - self.speculative_authenticate = None + self.speculative_authenticate: Optional[Mapping[str, Any]] = None self.address = address @staticmethod - def from_credentials(creds, address): + def from_credentials( + creds: MongoCredential, address: tuple[str, int] + ) -> Optional[_AuthContext]: spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism) if spec_cls: return spec_cls(creds, address) return None - def speculate_command(self): + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: raise NotImplementedError - def parse_response(self, hello): + def parse_response(self, hello: Hello) -> None: self.speculative_authenticate = hello.speculative_authenticate - def speculate_succeeded(self): + def speculate_succeeded(self) -> bool: return bool(self.speculative_authenticate) class _ScramContext(_AuthContext): - def __init__(self, credentials, address, mechanism): + def __init__( + self, credentials: MongoCredential, address: tuple[str, int], mechanism: str + ) -> None: super().__init__(credentials, address) - self.scram_data = None + self.scram_data: Optional[tuple[bytes, bytes]] = None self.mechanism = mechanism - def speculate_command(self): + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism) # The 'db' field is included only on the speculative command. cmd["db"] = self.credentials.source @@ -556,7 +579,7 @@ class _ScramContext(_AuthContext): class _X509Context(_AuthContext): - def speculate_command(self): + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")]) if self.credentials.username is not None: cmd["user"] = self.credentials.username @@ -564,7 +587,7 @@ class _X509Context(_AuthContext): class _OIDCContext(_AuthContext): - def speculate_command(self): + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: authenticator = _get_authenticator(self.credentials, self.address) cmd = authenticator.auth_start_cmd(False) if cmd is None: @@ -582,7 +605,9 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { } -def authenticate(credentials, sock_info, reauthenticate=False): +def authenticate( + credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False +) -> None: """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] diff --git a/pymongo/pool.py b/pymongo/pool.py index 5bae8ce87..2b498078c 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -758,7 +758,9 @@ class SocketInfo: cmd["saslSupportedMechs"] = creds.source + "." + creds.username auth_ctx = auth._AuthContext.from_credentials(creds, self.address) if auth_ctx: - cmd["speculativeAuthenticate"] = auth_ctx.speculate_command() + speculative_authenticate = auth_ctx.speculate_command() + if speculative_authenticate is not None: + cmd["speculativeAuthenticate"] = speculative_authenticate else: auth_ctx = None