Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-11-11 09:34:48 -06:00
commit 82978f9fa3
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
9 changed files with 381 additions and 324 deletions

File diff suppressed because it is too large Load Diff

View File

@ -46,6 +46,11 @@ export PROJECT="$project"
export PIP_QUIET=1
EOT
# Skip CSOT tests on non-linux platforms.
if [ "$(uname -s)" != "Linux" ]; then
echo "export SKIP_CSOT_TESTS=1" >> $SCRIPT_DIR/env.sh
fi
# Add these expansions to make it easier to call out tests scripts from the EVG yaml
cat <<EOT > expansion.yml
DRIVERS_TOOLS: "$DRIVERS_TOOLS"

View File

@ -57,17 +57,26 @@ class Host:
name: str
run_on: str
display_name: str
variables: dict[str, str] | None
# Hosts with toolchains.
HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8")
HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64")
HOSTS["win32"] = Host("win32", "windows-64-vsMulti-small", "Win32")
HOSTS["macos"] = Host("macos", "macos-14", "macOS")
HOSTS["macos-arm64"] = Host("macos-arm64", "macos-14-arm64", "macOS Arm64")
HOSTS["ubuntu20"] = Host("ubuntu20", "ubuntu2004-small", "Ubuntu-20")
HOSTS["ubuntu22"] = Host("ubuntu22", "ubuntu2204-small", "Ubuntu-22")
HOSTS["rhel7"] = Host("rhel7", "rhel79-small", "RHEL7")
HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8", dict())
HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64", dict())
HOSTS["win32"] = Host("win32", "windows-64-vsMulti-small", "Win32", dict())
HOSTS["macos"] = Host("macos", "macos-14", "macOS", dict())
HOSTS["macos-arm64"] = Host("macos-arm64", "macos-14-arm64", "macOS Arm64", dict())
HOSTS["ubuntu20"] = Host("ubuntu20", "ubuntu2004-small", "Ubuntu-20", dict())
HOSTS["ubuntu22"] = Host("ubuntu22", "ubuntu2204-small", "Ubuntu-22", dict())
HOSTS["rhel7"] = Host("rhel7", "rhel79-small", "RHEL7", dict())
DEFAULT_HOST = HOSTS["rhel8"]
# Other hosts
OTHER_HOSTS = ["RHEL9-FIPS", "RHEL8-zseries", "RHEL8-POWER8", "RHEL8-arm64"]
for name, run_on in zip(
OTHER_HOSTS, ["rhel92-fips", "rhel8-zseries-small", "rhel8-power-small", "rhel82-arm64-small"]
):
HOSTS[name] = Host(name, run_on, name, dict())
##############
@ -75,57 +84,77 @@ HOSTS["rhel7"] = Host("rhel7", "rhel79-small", "RHEL7")
##############
def create_variant(
def create_variant_generic(
task_names: list[str],
display_name: str,
*,
python: str | None = None,
version: str | None = None,
host: str | None = None,
host: Host | None = None,
default_run_on="rhel87-small",
expansions: dict | None = None,
**kwargs: Any,
) -> BuildVariant:
"""Create a build variant for the given inputs."""
task_refs = [EvgTaskRef(name=n) for n in task_names]
kwargs.setdefault("expansions", dict())
expansions = kwargs.pop("expansions", dict()).copy()
expansions = expansions and expansions.copy() or dict()
if "run_on" in kwargs:
run_on = kwargs.pop("run_on")
elif host:
run_on = [host.run_on]
if host.variables:
expansions.update(host.variables)
else:
host = host or "rhel8"
run_on = [HOSTS[host].run_on]
run_on = [default_run_on]
if isinstance(run_on, str):
run_on = [run_on]
name = display_name.replace(" ", "-").replace("*-", "").lower()
if python:
expansions["PYTHON_BINARY"] = get_python_binary(python, host)
if version:
expansions["VERSION"] = version
expansions = expansions or None
return BuildVariant(
name=name,
display_name=display_name,
tasks=task_refs,
expansions=expansions,
expansions=expansions or None,
run_on=run_on,
**kwargs,
)
def get_python_binary(python: str, host: str) -> str:
def create_variant(
task_names: list[str],
display_name: str,
*,
version: str | None = None,
host: Host | None = None,
python: str | None = None,
expansions: dict | None = None,
**kwargs: Any,
) -> BuildVariant:
expansions = expansions and expansions.copy() or dict()
if version:
expansions["VERSION"] = version
if python:
expansions["PYTHON_BINARY"] = get_python_binary(python, host)
return create_variant_generic(
task_names, display_name, version=version, host=host, expansions=expansions, **kwargs
)
def get_python_binary(python: str, host: Host) -> str:
"""Get the appropriate python binary given a python version and host."""
if host in ["win64", "win32"]:
if host == "win32":
name = host.name
if name in ["win64", "win32"]:
if name == "win32":
base = "C:/python/32"
else:
base = "C:/python"
python = python.replace(".", "")
return f"{base}/Python{python}/python.exe"
if host in ["rhel8", "ubuntu22", "ubuntu20", "rhel7"]:
if name in ["rhel8", "ubuntu22", "ubuntu20", "rhel7"]:
return f"/opt/python/{python}/bin/python3"
if host in ["macos", "macos-arm64"]:
if name in ["macos", "macos-arm64"]:
return f"/Library/Frameworks/Python.Framework/Versions/{python}/bin/python3"
raise ValueError(f"no match found for python {python} on {host}")
raise ValueError(f"no match found for python {python} on {name}")
def get_versions_from(min_version: str) -> list[str]:
@ -146,11 +175,11 @@ def get_versions_until(max_version: str) -> list[str]:
return versions
def get_display_name(base: str, host: str | None = None, **kwargs) -> str:
def get_display_name(base: str, host: Host | None = None, **kwargs) -> str:
"""Get the display name of a variant."""
display_name = base
if host is not None:
display_name += f" {HOSTS[host].display_name}"
display_name += f" {host.display_name}"
version = kwargs.pop("VERSION", None)
version = version or kwargs.pop("version", None)
if version:
@ -161,7 +190,9 @@ def get_display_name(base: str, host: str | None = None, **kwargs) -> str:
name = value
if key.lower() == "python":
if not value.startswith("pypy"):
name = f"py{value}"
name = f"Python{value}"
else:
name = f"PyPy{value.replace('pypy', '')}"
elif key.lower() in DISPLAY_LOOKUP:
name = DISPLAY_LOOKUP[key.lower()][value]
else:
@ -203,10 +234,10 @@ def create_ocsp_variants() -> list[BuildVariant]:
expansions = dict(AUTH="noauth", SSL="ssl", TOPOLOGY="server")
base_display = "OCSP"
# OCSP tests on rhel8 with all servers v4.4+ and all python versions.
# OCSP tests on default host with all servers v4.4+ and all python versions.
versions = [v for v in ALL_VERSIONS if v != "4.0"]
for version, python in zip_cycle(versions, ALL_PYTHONS):
host = "rhel8"
host = DEFAULT_HOST
variant = create_variant(
[".ocsp"],
get_display_name(base_display, host, version=version, python=python),
@ -220,7 +251,8 @@ def create_ocsp_variants() -> list[BuildVariant]:
# OCSP tests on Windows and MacOS.
# MongoDB servers on these hosts do not staple OCSP responses and only support RSA.
for host, version in product(["win64", "macos"], ["4.4", "8.0"]):
for host_name, version in product(["win64", "macos"], ["4.4", "8.0"]):
host = HOSTS[host_name]
python = CPYTHONS[0] if version == "4.4" else CPYTHONS[-1]
variant = create_variant(
[".ocsp-rsa !.ocsp-staple"],
@ -240,7 +272,7 @@ def create_server_variants() -> list[BuildVariant]:
variants = []
# Run the full matrix on linux with min and max CPython, and latest pypy.
host = "rhel8"
host = DEFAULT_HOST
# Prefix the display name with an asterisk so it is sorted first.
base_display_name = "* Test"
for python in [*MIN_MAX_PYTHON, PYPYS[-1]]:
@ -270,23 +302,17 @@ def create_server_variants() -> list[BuildVariant]:
variants.append(variant)
# Test a subset on each of the other platforms.
for host in ("macos", "macos-arm64", "win64", "win32"):
for host_name in ("macos", "macos-arm64", "win64", "win32"):
for python in MIN_MAX_PYTHON:
tasks = [f"{t} !.sync_async" for t in SUB_TASKS]
# MacOS arm64 only works on server versions 6.0+
if host == "macos-arm64":
if host_name == "macos-arm64":
tasks = []
for version in get_versions_from("6.0"):
tasks.extend(f"{t} .{version} !.sync_async" for t in SUB_TASKS)
expansions = dict(SKIP_CSOT_TESTS="true")
display_name = get_display_name(base_display_name, host, python=python, **expansions)
variant = create_variant(
tasks,
display_name,
python=python,
host=host,
expansions=expansions,
)
host = HOSTS[host_name]
display_name = get_display_name(base_display_name, host, python=python)
variant = create_variant(tasks, display_name, python=python, host=host)
variants.append(variant)
return variants
@ -305,7 +331,7 @@ def create_encryption_variants() -> list[BuildVariant]:
expansions["test_encryption_pyopenssl"] = "true"
return expansions
host = "rhel8"
host = DEFAULT_HOST
# Test against all server versions for the three main python versions.
encryptions = ["Encryption", "Encryption crypt_shared", "Encryption PyOpenSSL"]
@ -339,7 +365,8 @@ def create_encryption_variants() -> list[BuildVariant]:
# Test on macos and linux on one server version and topology for min and max python.
encryptions = ["Encryption", "Encryption crypt_shared"]
task_names = [".latest .replica_set .sync_async"]
for host, encryption, python in product(["macos", "win64"], encryptions, MIN_MAX_PYTHON):
for host_name, encryption, python in product(["macos", "win64"], encryptions, MIN_MAX_PYTHON):
host = HOSTS[host_name]
expansions = get_encryption_expansions(encryption)
display_name = get_display_name(encryption, host, python=python, **expansions)
variant = create_variant(
@ -357,7 +384,7 @@ def create_encryption_variants() -> list[BuildVariant]:
def create_load_balancer_variants():
# Load balancer tests - run all supported server versions using the lowest supported python.
host = "rhel8"
host = DEFAULT_HOST
batchtime = BATCHTIME_WEEK
versions = get_versions_from("6.0")
variants = []
@ -379,7 +406,7 @@ def create_load_balancer_variants():
def create_compression_variants():
# Compression tests - standalone versions of each server, across python versions, with and without c extensions.
# PyPy interpreters are always tested without extensions.
host = "rhel8"
host = DEFAULT_HOST
base_task = ".standalone .noauth .nossl .sync_async"
task_names = dict(snappy=[base_task], zlib=[base_task], zstd=[f"{base_task} !.4.0"])
variants = []
@ -423,11 +450,11 @@ def create_enterprise_auth_variants():
# All python versions across platforms.
for python in ALL_PYTHONS:
if python == CPYTHONS[0]:
host = "macos"
host = HOSTS["macos"]
elif python == CPYTHONS[-1]:
host = "win64"
host = HOSTS["win64"]
else:
host = "rhel8"
host = DEFAULT_HOST
display_name = get_display_name("Auth Enterprise", host, python=python, **expansions)
variant = create_variant(
["test-enterprise-auth"], display_name, host=host, python=python, expansions=expansions
@ -448,11 +475,11 @@ def create_pyopenssl_variants():
auth = "noauth" if python == CPYTHONS[0] else "auth"
ssl = "nossl" if auth == "noauth" else "ssl"
if python == CPYTHONS[0]:
host = "macos"
host = HOSTS["macos"]
elif python == CPYTHONS[-1]:
host = "win64"
host = HOSTS["win64"]
else:
host = "rhel8"
host = DEFAULT_HOST
display_name = get_display_name(base_name, host, python=python)
variant = create_variant(
@ -469,7 +496,7 @@ def create_pyopenssl_variants():
def create_storage_engine_variants():
host = "rhel8"
host = DEFAULT_HOST
engines = ["InMemory", "MMAPv1"]
variants = []
for engine in engines:
@ -492,7 +519,7 @@ def create_storage_engine_variants():
def create_stable_api_variants():
host = "rhel8"
host = DEFAULT_HOST
tags = ["versionedApi_tag"]
tasks = [f".standalone .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0")]
variants = []
@ -526,7 +553,7 @@ def create_stable_api_variants():
def create_green_framework_variants():
variants = []
tasks = [".standalone .noauth .nossl .sync_async"]
host = "rhel8"
host = DEFAULT_HOST
for python, framework in product([CPYTHONS[0], CPYTHONS[-2]], ["eventlet", "gevent"]):
expansions = dict(GREEN_FRAMEWORK=framework, AUTH="auth", SSL="ssl")
display_name = get_display_name(f"Green {framework.capitalize()}", host, python=python)
@ -539,7 +566,7 @@ def create_green_framework_variants():
def create_no_c_ext_variants():
variants = []
host = "rhel8"
host = DEFAULT_HOST
for python, topology in zip_cycle(CPYTHONS, TOPOLOGIES):
tasks = [f".{topology} .noauth .nossl .sync_async"]
expansions = dict()
@ -554,7 +581,7 @@ def create_no_c_ext_variants():
def create_atlas_data_lake_variants():
variants = []
host = "ubuntu22"
host = HOSTS["ubuntu22"]
for python, c_ext in product(MIN_MAX_PYTHON, C_EXTS):
tasks = ["atlas-data-lake-tests"]
expansions = dict(AUTH="auth")
@ -569,7 +596,7 @@ def create_atlas_data_lake_variants():
def create_mod_wsgi_variants():
variants = []
host = "ubuntu22"
host = HOSTS["ubuntu22"]
tasks = [
"mod-wsgi-standalone",
"mod-wsgi-replica-set",
@ -587,7 +614,7 @@ def create_mod_wsgi_variants():
def create_disable_test_commands_variants():
host = "rhel8"
host = DEFAULT_HOST
expansions = dict(AUTH="auth", SSL="ssl", DISABLE_TEST_COMMANDS="1")
python = CPYTHONS[0]
display_name = get_display_name("Disable test commands", host, python=python)
@ -596,7 +623,7 @@ def create_disable_test_commands_variants():
def create_serverless_variants():
host = "rhel8"
host = DEFAULT_HOST
batchtime = BATCHTIME_WEEK
expansions = dict(test_serverless="true", AUTH="auth", SSL="ssl")
tasks = ["serverless_task_group"]
@ -617,10 +644,11 @@ def create_serverless_variants():
def create_oidc_auth_variants():
variants = []
other_tasks = ["testazureoidc_task_group", "testgcpoidc_task_group", "testk8soidc_task_group"]
for host in ["ubuntu22", "macos", "win64"]:
for host_name in ["ubuntu22", "macos", "win64"]:
tasks = ["testoidc_task_group"]
if host == "ubuntu22":
if host_name == "ubuntu22":
tasks += other_tasks
host = HOSTS[host_name]
variants.append(
create_variant(
tasks,
@ -633,7 +661,7 @@ def create_oidc_auth_variants():
def create_search_index_variants():
host = "rhel8"
host = DEFAULT_HOST
python = CPYTHONS[0]
return [
create_variant(
@ -646,7 +674,7 @@ def create_search_index_variants():
def create_mockupdb_variants():
host = "rhel8"
host = DEFAULT_HOST
python = CPYTHONS[0]
return [
create_variant(
@ -659,7 +687,7 @@ def create_mockupdb_variants():
def create_doctests_variants():
host = "rhel8"
host = DEFAULT_HOST
python = CPYTHONS[0]
return [
create_variant(
@ -672,7 +700,7 @@ def create_doctests_variants():
def create_atlas_connect_variants():
host = "rhel8"
host = DEFAULT_HOST
return [
create_variant(
["atlas-connect"],
@ -696,13 +724,14 @@ def create_aws_auth_variants():
"aws-auth-test-latest",
]
for host, python in product(["ubuntu20", "win64", "macos"], MIN_MAX_PYTHON):
for host_name, python in product(["ubuntu20", "win64", "macos"], MIN_MAX_PYTHON):
expansions = dict()
if host != "ubuntu20":
if host_name != "ubuntu20":
expansions["skip_ECS_auth_test"] = "true"
if host == "macos":
if host_name == "macos":
expansions["skip_EC2_auth_test"] = "true"
expansions["skip_web_identity_auth_test"] = "true"
host = HOSTS[host_name]
variant = create_variant(
tasks,
get_display_name("Auth AWS", host, python=python),
@ -719,11 +748,11 @@ def create_alternative_hosts_variants():
batchtime = BATCHTIME_WEEK
variants = []
host = "rhel7"
host = HOSTS["rhel7"]
variants.append(
create_variant(
[".5.0 .standalone !.sync_async"],
get_display_name("OpenSSL 1.0.2", "rhel7", python=CPYTHONS[0], **expansions),
get_display_name("OpenSSL 1.0.2", host, python=CPYTHONS[0], **expansions),
host=host,
python=CPYTHONS[0],
batchtime=batchtime,
@ -731,16 +760,15 @@ def create_alternative_hosts_variants():
)
)
hosts = ["rhel92-fips", "rhel8-zseries-small", "rhel8-power-small", "rhel82-arm64-small"]
host_names = ["RHEL9-FIPS", "RHEL8-zseries", "RHEL8-POWER8", "RHEL8-arm64"]
for host, host_name in zip(hosts, host_names):
for host_name in OTHER_HOSTS:
host = HOSTS[host_name]
variants.append(
create_variant(
[".6.0 .standalone !.sync_async"],
display_name=get_display_name(f"Other hosts {host_name}", **expansions),
display_name=get_display_name("Other hosts", host, **expansions),
expansions=expansions,
batchtime=batchtime,
run_on=[host],
host=host,
)
)
return variants

View File

@ -55,7 +55,7 @@ def _get_authenticator(
properties = credentials.mechanism_properties
# Validate that the address is allowed.
if not properties.environment:
if properties.human_callback is not None:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:

View File

@ -100,8 +100,8 @@ def _validate_canonicalize_host_name(value: str | bool) -> str | bool:
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
user: Optional[str],
passwd: Optional[str],
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
@ -161,6 +161,8 @@ def _build_credentials_tuple(
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
if properties.get("ALLOWED_HOSTS", None) is not None and human_callback is None:
raise ConfigurationError("ALLOWED_HOSTS is only valid with OIDC_HUMAN_CALLBACK")
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
@ -207,7 +209,7 @@ def _build_credentials_tuple(
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
username=user or "",
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())

View File

@ -873,8 +873,10 @@ def get_validated_options(
validator = _get_validator(opt, URI_OPTIONS_VALIDATOR_MAP, normed_key=normed_key)
validated = validator(opt, value)
except (ValueError, TypeError, ConfigurationError) as exc:
if normed_key == "authmechanismproperties" and any(
p in str(exc) for p in _MECH_PROP_MUST_RAISE
if (
normed_key == "authmechanismproperties"
and any(p in str(exc) for p in _MECH_PROP_MUST_RAISE)
and "is not a supported auth mechanism property" not in str(exc)
):
raise
if warn:

View File

@ -55,7 +55,7 @@ def _get_authenticator(
properties = credentials.mechanism_properties
# Validate that the address is allowed.
if not properties.environment:
if properties.human_callback is not None:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:

View File

@ -1,5 +1,5 @@
mypy==1.13.0
pyright==1.1.385
pyright==1.1.388
typing_extensions
-r ./encryption.txt
-r ./ocsp.txt

View File

@ -38,11 +38,17 @@ from pymongo import MongoClient
from pymongo._azure_helpers import _get_azure_response
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.auth_oidc_shared import _get_k8s_token
from pymongo.auth_shared import _build_credentials_tuple
from pymongo.cursor_shared import CursorType
from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure
from pymongo.hello import HelloCompat
from pymongo.operations import InsertOne
from pymongo.synchronous.auth_oidc import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult
from pymongo.synchronous.auth_oidc import (
OIDCCallback,
OIDCCallbackContext,
OIDCCallbackResult,
_get_authenticator,
)
from pymongo.uri_parser import parse_uri
ROOT = Path(__file__).parent.parent.resolve()
@ -103,7 +109,6 @@ class OIDCTestBase(PyMongoTestCase):
client.close()
@pytest.mark.auth_oidc
class TestAuthOIDCHuman(OIDCTestBase):
uri: str
@ -838,12 +843,35 @@ class TestAuthOIDCMachine(OIDCTestBase):
self.create_client(authmechanismproperties=props)
def test_2_5_invalid_use_of_ALLOWED_HOSTS(self):
# Create an OIDC configured client with auth mechanism properties `{"ENVIRONMENT": "azure", "ALLOWED_HOSTS": []}`.
props: Dict = {"ENVIRONMENT": "azure", "ALLOWED_HOSTS": []}
# Create an OIDC configured client with auth mechanism properties `{"ENVIRONMENT": "test", "ALLOWED_HOSTS": []}`.
props: Dict = {"ENVIRONMENT": "test", "ALLOWED_HOSTS": []}
# Assert it returns a client configuration error.
with self.assertRaises(ConfigurationError):
self.create_client(authmechanismproperties=props)
# Create an OIDC configured client with auth mechanism properties `{"OIDC_CALLBACK": "<my_callback>", "ALLOWED_HOSTS": []}`.
props: Dict = {"OIDC_CALLBACK": self.create_request_cb(), "ALLOWED_HOSTS": []}
# Assert it returns a client configuration error.
with self.assertRaises(ConfigurationError):
self.create_client(authmechanismproperties=props)
def test_2_6_ALLOWED_HOSTS_defaults_ignored(self):
# Create a MongoCredential for OIDC with a machine callback.
props = {"OIDC_CALLBACK": self.create_request_cb()}
extra = dict(authmechanismproperties=props)
mongo_creds = _build_credentials_tuple("MONGODB-OIDC", None, "foo", None, extra, "test")
# Assert that creating an authenticator for example.com does not result in an error.
authenticator = _get_authenticator(mongo_creds, ("example.com", 30))
assert authenticator.properties.username == "foo"
# Create a MongoCredential for OIDC with an ENVIRONMENT.
props = {"ENVIRONMENT": "test"}
extra = dict(authmechanismproperties=props)
mongo_creds = _build_credentials_tuple("MONGODB-OIDC", None, None, None, extra, "test")
# Assert that creating an authenticator for example.com does not result in an error.
authenticator = _get_authenticator(mongo_creds, ("example.com", 30))
assert authenticator.properties.username == ""
def test_3_1_authentication_failure_with_cached_tokens_fetch_a_new_token_and_retry(self):
# Create a MongoClient and an OIDC callback that implements the provider logic.
client = self.create_client()
@ -909,7 +937,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Assert that the callback has been called once.
self.assertEqual(self.request_called, 1)
def test_4_1_reauthentication_succeds(self):
def test_4_1_reauthentication_succeeds(self):
# Create a ``MongoClient`` configured with a custom OIDC callback that
# implements the provider logic.
client = self.create_client()