diff --git a/doc/contributors.rst b/doc/contributors.rst index 2a4ca1ea4..b6e143440 100644 --- a/doc/contributors.rst +++ b/doc/contributors.rst @@ -98,3 +98,4 @@ The following is a list of people who have contributed to - Dainis Gorbunovs (DainisGorbunovs) - Iris Ho (sleepyStick) - Stephan Hof (stephan-hof) +- Casey Clements (caseyclements) diff --git a/pymongo/common.py b/pymongo/common.py index bda294af9..41d1e1050 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -20,6 +20,7 @@ import datetime import inspect import warnings from collections import OrderedDict, abc +from difflib import get_close_matches from typing import ( TYPE_CHECKING, Any, @@ -162,9 +163,12 @@ def clean_node(node: str) -> tuple[str, int]: return host.lower(), port -def raise_config_error(key: str, dummy: Any) -> NoReturn: +def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn: """Raise ConfigurationError with the given key name.""" - raise ConfigurationError(f"Unknown option {key}") + msg = f"Unknown option: {key}." + if suggestions: + msg += f" Did you mean one of ({', '.join(suggestions)}) or maybe a camelCase version of one? Refer to docstring." + raise ConfigurationError(msg) # Mapping of URI uuid representation options to valid subtypes. @@ -810,14 +814,24 @@ def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: """Validate optional authentication parameters.""" lower, value = validate(option, value) if lower not in _AUTH_OPTIONS: - raise ConfigurationError(f"Unknown authentication option: {option}") + raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}") return option, value +def _get_validator( + key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None +) -> Callable: + normed_key = normed_key or key + try: + return validators[normed_key] + except KeyError: + suggestions = get_close_matches(normed_key, validators, cutoff=0.2) + raise_config_error(key, suggestions) + + def validate(option: str, value: Any) -> tuple[str, Any]: """Generic validation function.""" - lower = option.lower() - validator = VALIDATORS.get(lower, raise_config_error) + validator = _get_validator(option, VALIDATORS, normed_key=option.lower()) value = validator(option, value) return option, value @@ -855,15 +869,15 @@ def get_validated_options( for opt, value in options.items(): normed_key = get_normed_key(opt) try: - validator = URI_OPTIONS_VALIDATOR_MAP.get(normed_key, raise_config_error) - value = validator(opt, value) # noqa: PLW2901 + validator = _get_validator(opt, URI_OPTIONS_VALIDATOR_MAP, normed_key=normed_key) + validated = validator(opt, value) except (ValueError, TypeError, ConfigurationError) as exc: if warn: warnings.warn(str(exc), stacklevel=2) else: raise else: - validated_options[get_setter_key(normed_key)] = value + validated_options[get_setter_key(normed_key)] = validated return validated_options diff --git a/test/test_client.py b/test/test_client.py index aceb15312..089b1673b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -21,6 +21,7 @@ import copy import datetime import gc import os +import re import signal import socket import struct @@ -535,6 +536,14 @@ class ClientUnitTest(unittest.TestCase): self.assertIsInstance(c.options.retry_writes, bool) self.assertIsInstance(c.options.retry_reads, bool) + def test_validate_suggestion(self): + """Validate kwargs in constructor.""" + for typo in ["auth", "Auth", "AUTH"]: + expected = f"Unknown option: {typo}. Did you mean one of (authsource, authmechanism, authoidcallowedhosts) or maybe a camelCase version of one? Refer to docstring." + expected = re.escape(expected) + with self.assertRaisesRegex(ConfigurationError, expected): + MongoClient(**{typo: "standard"}) # type: ignore[arg-type] + class TestClient(IntegrationTest): def test_multiple_uris(self): diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index d5a25f590..a4ad908e1 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -144,6 +144,12 @@ class TestURI(unittest.TestCase): self.assertEqual({"authsource": "foobar"}, split_options("authSource=foobar")) self.assertEqual({"maxpoolsize": 50}, split_options("maxpoolsize=50")) + # Test suggestions given when invalid kwarg passed + + expected = r"Unknown option: auth. Did you mean one of \(authsource, authmechanism, timeoutms\) or maybe a camelCase version of one\? Refer to docstring." + with self.assertRaisesRegex(ConfigurationError, expected): + split_options("auth=GSSAPI") + def test_parse_uri(self): self.assertRaises(InvalidURI, parse_uri, "http://foobar.com") self.assertRaises(InvalidURI, parse_uri, "http://foo@foobar.com")