Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-04-10 16:45:38 -05:00
commit 059c19f37a
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
24 changed files with 1269 additions and 1156 deletions

View File

@ -991,6 +991,30 @@ task_groups:
tasks:
- oidc-auth-test-azure-latest
- name: testgcpoidc_task_group
setup_group:
- func: fetch source
- func: prepare resources
- func: fix absolute paths
- func: make files executable
- command: subprocess.exec
params:
binary: bash
env:
GCPOIDC_VMNAME_PREFIX: "PYTHON_DRIVER"
args:
- ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/setup.sh
teardown_task:
- command: subprocess.exec
params:
binary: bash
args:
- ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/teardown.sh
setup_group_can_fail_task: true
setup_group_timeout_secs: 1800
tasks:
- oidc-auth-test-gcp-latest
- name: testoidc_task_group
setup_group:
- func: fetch source
@ -1966,6 +1990,25 @@ tasks:
export AZUREOIDC_TEST_CMD="OIDC_ENV=azure ./.evergreen/run-mongodb-oidc-test.sh"
bash $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/run-driver-test.sh
- name: "oidc-auth-test-gcp-latest"
commands:
- command: shell.exec
params:
shell: bash
script: |-
set -o errexit
${PREPARE_SHELL}
cd src
git add .
git commit -m "add files"
export GCPOIDC_DRIVERS_TAR_FILE=/tmp/mongo-python-driver.tgz
git archive -o $GCPOIDC_DRIVERS_TAR_FILE HEAD
# Define the command to run on the VM.
# Ensure that we source the environment file created for us, set up any other variables we need,
# and then run our test suite on the vm.
export GCPOIDC_TEST_CMD="OIDC_ENV=gcp ./.evergreen/run-mongodb-oidc-test.sh"
bash $DRIVERS_TOOLS/.evergreen/auth_oidc/gcp/run-driver-test.sh
- name: "test-fips-standalone"
tags: ["fips"]
commands:
@ -2995,18 +3038,25 @@ buildvariants:
- matrix_name: "oidc-auth-test"
matrix_spec:
platform: [ rhel8, macos-1100, windows-64-vsMulti-small ]
display_name: "MONGODB-OIDC Auth ${platform}"
display_name: "OIDC Auth ${platform}"
tasks:
- name: testoidc_task_group
batchtime: 20160 # 14 days
- name: testazureoidc-variant
display_name: "Azure OIDC"
run_on: ubuntu2004-small
display_name: "OIDC Auth Azure"
run_on: ubuntu2204-small
tasks:
- name: testazureoidc_task_group
batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README
- name: testgcpoidc-variant
display_name: "OIDC Auth GCP"
run_on: ubuntu2204-small
tasks:
- name: testgcpoidc_task_group
batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README
- matrix_name: "aws-auth-test"
matrix_spec:
platform: [ubuntu-20.04]

View File

@ -1,7 +1,7 @@
#!/bin/bash
set +x # Disable debug trace
set -o errexit # Exit the script with error if any of the commands fail
set -eu
echo "Running MONGODB-OIDC authentication tests"
@ -18,6 +18,9 @@ if [ $OIDC_ENV == "test" ]; then
elif [ $OIDC_ENV == "azure" ]; then
source ./env.sh
elif [ $OIDC_ENV == "gcp" ]; then
source ./secrets-export.sh
else
echo "Unrecognized OIDC_ENV $OIDC_ENV"
exit 1

View File

@ -41,6 +41,8 @@ PyMongo 4.7 brings a number of improvements including:
:attr:`pymongo.monitoring.ConnectionReadyEvent.duration` properties.
- Added the ``type`` and ``kwargs`` arguments to :class:`~pymongo.operations.SearchIndexModel` to enable
creating vector search indexes in MongoDB Atlas.
- Fixed a bug where ``read_concern`` and ``write_concern`` were improperly added to
:meth:`~pymongo.collection.Collection.list_search_indexes` queries.
Unavoidable breaking changes

39
pymongo/_gcp_helpers.py Normal file
View File

@ -0,0 +1,39 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GCP helpers."""
from __future__ import annotations
from typing import Any
from urllib.request import Request, urlopen
def _get_gcp_response(resource: str, timeout: float = 5) -> dict[str, Any]:
url = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity"
url += f"?audience={resource}"
headers = {"Metadata-Flavor": "Google", "Accept": "application/json"}
request = Request(url, headers=headers) # noqa: S310
try:
with urlopen(request, timeout=timeout) as response: # noqa: S310
status = response.status
body = response.read().decode("utf8")
except Exception as e:
msg = "Failed to acquire IMDS access token: %s" % e
raise ValueError(msg) from None
if status != 200:
msg = "Failed to acquire IMDS access token."
raise ValueError(msg)
return dict(access_token=body)

View File

@ -41,6 +41,7 @@ from pymongo.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
@ -207,6 +208,13 @@ def _build_credentials_tuple(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:

View File

@ -26,7 +26,9 @@ import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.auth import MongoCredential
@ -36,7 +38,7 @@ if TYPE_CHECKING:
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@ -133,6 +135,15 @@ class _OIDCAzureCallback(OIDCCallback):
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = token_resource
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])
@dataclass
class _OIDCAuthenticator:
username: str
@ -179,30 +190,43 @@ class _OIDCAuthenticator:
def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails, it's possible the cached access token is expired. In
# that case, invalidate the access token, fetch a new access token, and try
# to authenticate again.
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
# fails for any other reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except Exception: # noqa: S110
pass
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_machine(conn)
raise
return self._sasl_start_jwt(conn)
def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
# reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except Exception: # noqa: S110
pass
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_human(conn)
raise
# If we have a cached refresh token, try a JwtStepRequest with that.
# If authentication fails with error code 18, invalidate the access and
# refresh tokens, and try to authenticate again. If authentication fails for
# any other reason, raise the error to the user.
if self.refresh_token:
try:
return self._sasl_start_jwt(conn)
except Exception: # noqa: S110
pass
except OperationFailure as e:
if self._is_auth_error(e):
self.refresh_token = None
return self._authenticate_human(conn)
raise
# Start a new Two-Step SASL conversation.
# Run a PrincipalStepRequest to get the IdpInfo.
@ -270,10 +294,16 @@ class _OIDCAuthenticator:
def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
try:
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure:
self._invalidate(conn)
except OperationFailure as e:
if self._is_auth_error(e):
self._invalidate(conn)
raise
def _is_auth_error(self, err: Exception) -> bool:
if not isinstance(err, OperationFailure):
return False
return err.code == _AUTHENTICATION_FAILURE_CODE
def _invalidate(self, conn: Connection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.

View File

@ -72,6 +72,7 @@ from pymongo.operations import (
_IndexList,
_Op,
)
from pymongo.read_concern import DEFAULT_READ_CONCERN, ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import (
BulkWriteResult,
@ -81,7 +82,7 @@ from pymongo.results import (
UpdateResult,
)
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.write_concern import WriteConcern, validate_boolean
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
T = TypeVar("T")
@ -119,7 +120,6 @@ if TYPE_CHECKING:
from pymongo.collation import Collation
from pymongo.database import Database
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.server import Server
@ -2364,7 +2364,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
pipeline = [{"$listSearchIndexes": {"name": name}}]
coll = self.with_options(
codec_options=DEFAULT_CODEC_OPTIONS, read_preference=ReadPreference.PRIMARY
codec_options=DEFAULT_CODEC_OPTIONS,
read_preference=ReadPreference.PRIMARY,
write_concern=DEFAULT_WRITE_CONCERN,
read_concern=DEFAULT_READ_CONCERN,
)
cmd = _CollectionAggregationCommand(
coll,

View File

@ -426,7 +426,6 @@ _MECHANISM_PROPS = frozenset(
"AWS_SESSION_TOKEN",
"ENVIRONMENT",
"TOKEN_RESOURCE",
"ALLOWED_HOSTS",
]
)

View File

@ -90,6 +90,9 @@ _RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391
# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18
def _gen_index_name(keys: _IndexList) -> str:
"""Generate an index name from the set of fields it is over."""

View File

@ -335,8 +335,12 @@ class SrvMonitor(MonitorBase):
self._seedlist = self._settings._seeds
assert isinstance(self._settings.fqdn, str)
self._fqdn: str = self._settings.fqdn
self._startup_time = time.monotonic()
def _run(self) -> None:
# Don't poll right after creation, wait 60 seconds first
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
return
seedlist = self._get_seedlist()
if seedlist:
self._seedlist = seedlist

View File

@ -592,7 +592,7 @@ class SearchIndexModel:
self,
definition: Mapping[str, Any],
name: Optional[str] = None,
type: Optional[str] = "search",
type: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a Search Index instance.
@ -613,7 +613,8 @@ class SearchIndexModel:
if name is not None:
self.__document["name"] = name
self.__document["definition"] = definition
self.__document["type"] = type
if type is not None:
self.__document["type"] = type
self.__document.update(kwargs)
@property

View File

@ -265,8 +265,14 @@ class TopologyDescription:
def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
if not selection:
return []
round_trip_times: list[float] = []
for server in selection.server_descriptions:
if server.round_trip_time is None:
config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}"
raise ConfigurationError(config_err_msg)
round_trip_times.append(server.round_trip_time)
# Round trip time in seconds.
fastest = min(cast(float, s.round_trip_time) for s in selection.server_descriptions)
fastest = min(round_trip_times)
threshold = self._topology_settings.local_threshold_ms / 1000.0
return [
s

View File

@ -497,6 +497,12 @@
"valid": false,
"credential": null
},
{
"description": "should throw an exception custom callback is chosen but no callback is provided (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:custom",
"valid": false,
"credential": null
},
{
"description": "should throw an exception if neither provider nor callbacks specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
@ -540,10 +546,37 @@
"credential": null
},
{
"description": "should throw and exception if no token audience is given for azure provider (MONGODB-OIDC)",
"description": "should throw an exception if no token audience is given for azure provider (MONGODB-OIDC)",
"uri": "mongodb://username@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure",
"valid": false,
"credential": null
},
{
"description": "should recognise the mechanism with gcp provider (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:foo",
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"ENVIRONMENT": "gcp",
"TOKEN_RESOURCE": "foo"
}
}
},
{
"description": "should throw an error for a username and password with gcp provider (MONGODB-OIDC)",
"uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:foo",
"valid": false,
"credential": null
},
{
"description": "should throw an error if not TOKEN_RESOURCE with gcp provider (MONGODB-OIDC)",
"uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp",
"valid": false,
"credential": null
}
]
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,191 +0,0 @@
{
"description": "reauthenticate_with_retry",
"schemaVersion": "1.12",
"runOnRequirements": [
{
"minServerVersion": "6.3",
"auth": true
}
],
"createEntities": [
{
"client": {
"id": "client0",
"uriOptions": {
"retryReads": true,
"retryWrites": true
},
"observeEvents": [
"commandStartedEvent",
"commandSucceededEvent",
"commandFailedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "db"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "collName"
}
}
],
"initialData": [
{
"collectionName": "collName",
"databaseName": "db",
"documents": []
}
],
"tests": [
{
"description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"errorCode": 391
}
}
}
},
{
"name": "find",
"arguments": {
"filter": {}
},
"object": "collection0",
"expectResult": []
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandFailedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
}
]
}
]
},
{
"description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorCode": 391
}
}
}
},
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"_id": 1,
"x": 1
}
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandSucceededEvent": {
"commandName": "insert"
}
}
]
}
]
}
]
}

View File

@ -1,191 +0,0 @@
{
"description": "reauthenticate_without_retry",
"schemaVersion": "1.12",
"runOnRequirements": [
{
"minServerVersion": "6.3",
"auth": true
}
],
"createEntities": [
{
"client": {
"id": "client0",
"uriOptions": {
"retryReads": false,
"retryWrites": false
},
"observeEvents": [
"commandStartedEvent",
"commandSucceededEvent",
"commandFailedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "db"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "collName"
}
}
],
"initialData": [
{
"collectionName": "collName",
"databaseName": "db",
"documents": []
}
],
"tests": [
{
"description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"errorCode": 391
}
}
}
},
{
"name": "find",
"arguments": {
"filter": {}
},
"object": "collection0",
"expectResult": []
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandFailedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
}
]
}
]
},
{
"description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorCode": 391
}
}
}
},
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"_id": 1,
"x": 1
}
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandSucceededEvent": {
"commandName": "insert"
}
}
]
}
]
}
]
}

View File

@ -27,16 +27,15 @@ from typing import Dict
sys.path[0:0] = [""]
import pprint
from test.unified_format import generate_test_classes
from test.utils import EventListener
from bson import SON
from pymongo import MongoClient
from pymongo._azure_helpers import _get_azure_response
from pymongo.auth_oidc import (
OIDCCallback,
OIDCCallbackResult,
)
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.auth_oidc import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult
from pymongo.cursor import CursorType
from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure
from pymongo.hello import HelloCompat
@ -75,10 +74,12 @@ class OIDCTestBase(unittest.TestCase):
return fid.read()
elif ENVIRON == "azure":
opts = parse_uri(self.uri_single)["options"]
resource = opts["authmechanismproperties"]["TOKEN_RESOURCE"]
return _get_azure_response(resource, username)["access_token"]
else:
raise RuntimeError(f"Invalid ENVIRONMENT {ENVIRON}")
token_aud = opts["authmechanismproperties"]["TOKEN_RESOURCE"]
return _get_azure_response(token_aud, username)["access_token"]
elif ENVIRON == "gcp":
opts = parse_uri(self.uri_single)["options"]
token_aud = opts["authmechanismproperties"]["TOKEN_RESOURCE"]
return _get_gcp_response(token_aud, username)["access_token"]
@contextmanager
def fail_point(self, command_args):
@ -103,15 +104,24 @@ class TestAuthOIDCHuman(OIDCTestBase):
raise ValueError("Missing OIDC_DOMAIN")
super().setUpClass()
def setUp(self):
self.refresh_present = 0
super().setUp()
def create_request_cb(self, username="test_user1", sleep=0):
def request_token(context):
def request_token(context: OIDCCallbackContext):
# Validate the info.
self.assertIsInstance(context.idp_info.issuer, str)
self.assertIsInstance(context.idp_info.clientId, str)
if context.idp_info.clientId is not None:
self.assertIsInstance(context.idp_info.clientId, str)
# Validate the timeout.
timeout_seconds = context.timeout_seconds
self.assertEqual(timeout_seconds, 60 * 5)
if context.refresh_token:
self.refresh_present += 1
token = self.get_token(username)
resp = OIDCCallbackResult(access_token=token, refresh_token=token)
@ -127,7 +137,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
def create_client(self, *args, **kwargs):
username = kwargs.get("username", "test_user1")
if kwargs.get("username"):
if kwargs.get("username") in ["test_user1", "test_user2"]:
kwargs["username"] = f"{username}@{DOMAIN}"
request_cb = kwargs.pop("request_cb", self.create_request_cb(username=username))
props = kwargs.pop("authmechanismproperties", {"OIDC_HUMAN_CALLBACK": request_cb})
@ -215,6 +225,26 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Close the client.
client.close()
def test_1_7_allowed_hosts_in_connection_string_ignored(self):
# Create an OIDC configured client with the connection string: `mongodb+srv://example.com/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D` and a Human Callback.
# Assert that the creation of the client raises a configuration error.
uri = "mongodb+srv://example.com?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D"
with self.assertRaises(ConfigurationError), warnings.catch_warnings():
warnings.simplefilter("ignore")
_ = MongoClient(
uri, authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb())
)
def test_1_8_machine_idp_human_callback(self):
if not os.environ.get("OIDC_IS_LOCAL"):
raise unittest.SkipTest("Test Requires Local OIDC server")
# Create a client with MONGODB_URI_SINGLE, a username of test_machine, authMechanism=MONGODB-OIDC, and the OIDC human callback.
client = self.create_client(username="test_machine")
# Perform a find operation that succeeds.
client.test.test.find_one()
# Close the client.
client.close()
def test_2_1_valid_callback_inputs(self):
# Create a MongoClient with a human callback that validates its inputs and returns a valid access token.
client = self.create_client()
@ -224,7 +254,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Close the client.
client.close()
def test_2_2_OIDC_HUMAN_CALLBACK_returns_missing_data(self):
def test_2_2_callback_returns_missing_data(self):
# Create a MongoClient with a human callback that returns data not conforming to the OIDCCredential with missing fields.
class CustomCB(OIDCCallback):
def fetch(self, ctx):
@ -237,6 +267,29 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Close the client.
client.close()
def test_2_3_refresh_token_is_passed_to_the_callback(self):
# Create a MongoClient with a human callback that checks for the presence of a refresh token.
client = self.create_client()
# Perform a find operation that succeeds.
client.test.test.find_one()
# Set a fail point for ``find`` commands.
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a ``find`` operation that succeeds.
client.test.test.find_one()
# Assert that the callback has been called twice.
self.assertEqual(self.request_called, 2)
# Assert that the refresh token was used once.
self.assertEqual(self.refresh_present, 1)
def test_3_1_uses_speculative_authentication_if_there_is_a_cached_token(self):
# Create a client with a human callback that returns a valid token.
client = self.create_client()
@ -255,8 +308,8 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Set a fail point for ``saslStart`` commands.
with self.fail_point(
{
"mode": "alwaysOn",
"data": {"failCommands": ["saslStart"], "errorCode": 20},
"mode": {"times": 1},
"data": {"failCommands": ["saslStart"], "errorCode": 18},
}
):
# Perform a ``find`` operation that succeeds
@ -272,8 +325,8 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Set a fail point for ``saslStart`` commands.
with self.fail_point(
{
"mode": "alwaysOn",
"data": {"failCommands": ["saslStart"], "errorCode": 20},
"mode": {"times": 1},
"data": {"failCommands": ["saslStart"], "errorCode": 18},
}
):
# Perform a ``find`` operation that fails.
@ -374,8 +427,16 @@ class TestAuthOIDCHuman(OIDCTestBase):
client.close()
def test_4_3_reauthenticate_succeeds_after_refresh_fails(self):
# Create a client with a human callback that returns a valid token.
client = self.create_client()
# Create a default OIDC client with a human callback that returns an invalid refresh token
cb = self.create_request_cb()
class CustomRequest(OIDCCallback):
def fetch(self, *args, **kwargs):
result = cb.fetch(*args, **kwargs)
result.refresh_token = "bad"
return result
client = self.create_client(request_cb=CustomRequest())
# Perform a find operation that succeeds.
client.test.test.find_one()
@ -386,38 +447,56 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Force a reauthenication using a fail point.
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client.test.test.find_one()
# Assert that the human callback has been called 3 times.
self.assertEqual(self.request_called, 3)
# Assert that the human callback has been called 2 times.
self.assertEqual(self.request_called, 2)
# Close the client.
client.close()
def test_4_4_reauthenticate_fails(self):
# Create a client with a human callback that returns a valid token.
client = self.create_client()
# Create a default OIDC client with a human callback that returns invalid refresh tokens and
# Returns invalid access tokens after the first access.
cb = self.create_request_cb()
class CustomRequest(OIDCCallback):
fetch_called = 0
def fetch(self, *args, **kwargs):
self.fetch_called += 1
result = cb.fetch(*args, **kwargs)
result.refresh_token = "bad"
if self.fetch_called > 1:
result.access_token = "bad"
return result
client = self.create_client(request_cb=CustomRequest())
# Perform a find operation that succeeds (to force a speculative auth).
client.test.test.find_one()
# Assert that the human callback has been called once.
self.assertEqual(self.request_called, 1)
# Force a reauthentication using a failCommand.
with self.fail_point(
{
"mode": {"times": 3},
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that fails.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
# Assert that the human callback has been called two times.
self.assertEqual(self.request_called, 2)
# Assert that the human callback has been called three times.
self.assertEqual(self.request_called, 3)
# Close the client.
client.close()
@ -814,6 +893,33 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Close the client.
client.close()
def test_3_3_unexpected_error_code_does_not_clear_cache(self):
# Create a ``MongoClient`` with a human callback that returns a valid token
client = self.create_client()
# Set a fail point for ``saslStart`` commands.
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["saslStart"], "errorCode": 20},
}
):
# Perform a ``find`` operation that fails.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
# Assert that the callback has been called once.
self.assertEqual(self.request_called, 1)
# Perform a ``find`` operation that succeeds.
client.test.test.find_one()
# Assert that the callback has been called once.
self.assertEqual(self.request_called, 1)
# Close the client.
client.close()
def test_4_reauthentication(self):
# Create a ``MongoClient`` configured with a custom OIDC callback that
# implements the provider logic.

View File

@ -50,8 +50,7 @@
"mappings": {
"dynamic": true
}
},
"type": "search"
}
}
},
"expectError": {
@ -74,8 +73,7 @@
"mappings": {
"dynamic": true
}
},
"type": "search"
}
}
],
"$db": "database0"
@ -99,8 +97,7 @@
"dynamic": true
}
},
"name": "test index",
"type": "search"
"name": "test index"
}
},
"expectError": {
@ -124,68 +121,7 @@
"dynamic": true
}
},
"name": "test index",
"type": "search"
}
],
"$db": "database0"
}
}
}
]
}
]
},
{
"description": "create a vector search index",
"operations": [
{
"name": "createSearchIndex",
"object": "collection0",
"arguments": {
"model": {
"definition": {
"fields": [
{
"type": "vector",
"path": "plot_embedding",
"numDimensions": 1536,
"similarity": "euclidean"
}
]
},
"name": "test index",
"type": "vectorSearch"
}
},
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"createSearchIndexes": "collection0",
"indexes": [
{
"definition": {
"fields": [
{
"type": "vector",
"path": "plot_embedding",
"numDimensions": 1536,
"similarity": "euclidean"
}
]
},
"name": "test index",
"type": "vectorSearch"
"name": "test index"
}
],
"$db": "database0"

View File

@ -83,8 +83,7 @@
"mappings": {
"dynamic": true
}
},
"type": "search"
}
}
]
},
@ -108,8 +107,7 @@
"mappings": {
"dynamic": true
}
},
"type": "search"
}
}
],
"$db": "database0"
@ -134,8 +132,7 @@
"dynamic": true
}
},
"name": "test index",
"type": "search"
"name": "test index"
}
]
},
@ -160,70 +157,7 @@
"dynamic": true
}
},
"name": "test index",
"type": "search"
}
],
"$db": "database0"
}
}
}
]
}
]
},
{
"description": "create a vector search index",
"operations": [
{
"name": "createSearchIndexes",
"object": "collection0",
"arguments": {
"models": [
{
"definition": {
"fields": [
{
"type": "vector",
"path": "plot_embedding",
"numDimensions": 1536,
"similarity": "euclidean"
}
]
},
"name": "test index",
"type": "vectorSearch"
}
]
},
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"createSearchIndexes": "collection0",
"indexes": [
{
"definition": {
"fields": [
{
"type": "vector",
"path": "plot_embedding",
"numDimensions": 1536,
"similarity": "euclidean"
}
]
},
"name": "test index",
"type": "vectorSearch"
"name": "test index"
}
],
"$db": "database0"

View File

@ -0,0 +1,252 @@
{
"description": "search index operations ignore read and write concern",
"schemaVersion": "1.4",
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"uriOptions": {
"readConcernLevel": "local",
"w": 1
},
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "database0"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "collection0"
}
}
],
"runOnRequirements": [
{
"minServerVersion": "7.0.0",
"topologies": [
"replicaset",
"load-balanced",
"sharded"
],
"serverless": "forbid"
}
],
"tests": [
{
"description": "createSearchIndex ignores read and write concern",
"operations": [
{
"name": "createSearchIndex",
"object": "collection0",
"arguments": {
"model": {
"definition": {
"mappings": {
"dynamic": true
}
}
}
},
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"createSearchIndexes": "collection0",
"indexes": [
{
"definition": {
"mappings": {
"dynamic": true
}
}
}
],
"$db": "database0",
"writeConcern": {
"$$exists": false
},
"readConcern": {
"$$exists": false
}
}
}
}
]
}
]
},
{
"description": "createSearchIndexes ignores read and write concern",
"operations": [
{
"name": "createSearchIndexes",
"object": "collection0",
"arguments": {
"models": []
},
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"createSearchIndexes": "collection0",
"indexes": [],
"$db": "database0",
"writeConcern": {
"$$exists": false
},
"readConcern": {
"$$exists": false
}
}
}
}
]
}
]
},
{
"description": "dropSearchIndex ignores read and write concern",
"operations": [
{
"name": "dropSearchIndex",
"object": "collection0",
"arguments": {
"name": "test index"
},
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"dropSearchIndex": "collection0",
"name": "test index",
"$db": "database0",
"writeConcern": {
"$$exists": false
},
"readConcern": {
"$$exists": false
}
}
}
}
]
}
]
},
{
"description": "listSearchIndexes ignores read and write concern",
"operations": [
{
"name": "listSearchIndexes",
"object": "collection0",
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"aggregate": "collection0",
"pipeline": [
{
"$listSearchIndexes": {}
}
],
"writeConcern": {
"$$exists": false
},
"readConcern": {
"$$exists": false
}
}
}
}
]
}
]
},
{
"description": "updateSearchIndex ignores the read and write concern",
"operations": [
{
"name": "updateSearchIndex",
"object": "collection0",
"arguments": {
"name": "test index",
"definition": {}
},
"expectError": {
"isError": true,
"errorContains": "Atlas"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"updateSearchIndex": "collection0",
"name": "test index",
"definition": {},
"$db": "database0",
"writeConcern": {
"$$exists": false
},
"readConcern": {
"$$exists": false
}
}
}
}
]
}
]
}
]
}

View File

@ -1,4 +1,4 @@
# Copyright 2015 MongoDB, Inc.
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the MongoDB Driver Performance Benchmarking Spec."""
"""Tests for the MongoDB Driver Performance Benchmarking Spec.
See https://github.com/mongodb/specifications/blob/master/source/benchmarking/benchmarking.md
To set up the benchmarks locally::
python -m pip install simplejson
git clone --depth 1 https://github.com/mongodb/specifications.git
pushd specifications/source/benchmarking/data
tar xf extended_bson.tgz
tar xf parallel.tgz
tar xf single_and_multi_document.tgz
popd
export TEST_PATH="specifications/source/benchmarking/data"
export OUTPUT_FILE="results.json"
Then to run all benchmarks quickly::
FASTBENCH=1 python test/performance/perf_test.py -v
To run individual benchmarks quickly::
FASTBENCH=1 python test/performance/perf_test.py -v TestRunCommand TestFindManyAndEmptyCursor
"""
from __future__ import annotations
import multiprocessing as mp
@ -36,9 +60,18 @@ from bson import decode, encode, json_util
from gridfs import GridFSBucket
from pymongo import MongoClient
# Spec says to use at least 1 minute cumulative execution time and up to 100 iterations or 5 minutes but that
# makes the benchmarks too slow. Instead, we use at least 30 seconds and at most 60 seconds.
NUM_ITERATIONS = 100
MAX_ITERATION_TIME = 300
MIN_ITERATION_TIME = 30
MAX_ITERATION_TIME = 60
NUM_DOCS = 10000
# When debugging or prototyping it's often useful to run the benchmarks locally, set FASTBENCH=1 to run quickly.
if bool(os.getenv("FASTBENCH")):
NUM_ITERATIONS = 2
MIN_ITERATION_TIME = 0.1
MAX_ITERATION_TIME = 0.5
NUM_DOCS = 1000
TEST_PATH = os.environ.get(
"TEST_PATH", os.path.join(os.path.dirname(os.path.realpath(__file__)), os.path.join("data"))
@ -88,7 +121,7 @@ class PerformanceTest:
megabytes_per_sec = self.data_size / median / 1000000
print(
f"Completed {self.__class__.__name__} {megabytes_per_sec:.3f} MB/s, MEDIAN={self.percentile(50):.3f}s, "
f"total time={duration:.3f}s"
f"total time={duration:.3f}s, iterations={len(self.results)}"
)
result_data.append(
{
@ -125,19 +158,25 @@ class PerformanceTest:
def runTest(self):
results = []
start = time.monotonic()
for i in range(NUM_ITERATIONS):
if time.monotonic() - start > MAX_ITERATION_TIME:
with warnings.catch_warnings():
warnings.simplefilter("default")
warnings.warn(
f"Test timed out after {MAX_ITERATION_TIME}s, completed {i}/{NUM_ITERATIONS} iterations."
)
break
i = 0
while True:
i += 1
self.before()
with Timer() as timer:
self.do_task()
self.after()
results.append(timer.interval)
duration = time.monotonic() - start
if duration > MIN_ITERATION_TIME and i >= NUM_ITERATIONS:
break
if duration > MAX_ITERATION_TIME:
with warnings.catch_warnings():
warnings.simplefilter("default")
warnings.warn(
f"{self.__class__.__name__} timed out after {MAX_ITERATION_TIME}s, completed {i}/{NUM_ITERATIONS} iterations."
)
break
self.results = results

View File

@ -30,6 +30,8 @@ from test.utils import AllowListEventListener, EventListener
from pymongo import MongoClient
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management")
@ -60,7 +62,17 @@ class TestCreateSearchIndex(IntegrationTest):
listener.reset()
with self.assertRaises(OperationFailure):
coll.create_search_index({"definition": definition, "arbitraryOption": 1})
self.assertIn("arbitraryOption", listener.events[0].command["indexes"][0])
self.assertEqual(
{"definition": definition, "arbitraryOption": 1},
listener.events[0].command["indexes"][0],
)
listener.reset()
with self.assertRaises(OperationFailure):
coll.create_search_index({"definition": definition, "type": "search"})
self.assertEqual(
{"definition": definition, "type": "search"}, listener.events[0].command["indexes"][0]
)
class SearchIndexIntegrationBase(unittest.TestCase):
@ -257,6 +269,33 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
# Run a ``dropSearchIndex`` command and assert that no error is thrown.
coll0.drop_search_index("foo")
def test_case_6(self):
"""Driver can successfully create and list search indexes with non-default readConcern and writeConcern."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
coll0.insert_one({})
# Apply a write concern ``WriteConcern(w=1)`` and a read concern with ``ReadConcern(level="majority")`` to ``coll0``.
coll0 = coll0.with_options(
write_concern=WriteConcern(w="1"), read_concern=ReadConcern(level="majority")
)
# Create a new search index on ``coll0`` with the ``createSearchIndex`` helper.
name = "test-search-index-case6"
model = {"name": name, "definition": {"mappings": {"dynamic": False}}}
resp = coll0.create_search_index(model)
# Assert that the command returns the name of the index: ``"test-search-index-case6"``.
self.assertEqual(resp, name)
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``:
# - An index with the ``name`` of ``test-search-index-case6`` is present and the index has a field ``queryable`` with a value of ``true``.
index = self.wait_for_ready(coll0, name)
# Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }``
self.assertIn("latestDefinition", index)
self.assertEqual(index["latestDefinition"], model["definition"])
def test_case_7(self):
"""Driver handles index types."""

View File

@ -111,7 +111,7 @@ class TestSrvPolling(unittest.TestCase):
def get_nodelist(self, client):
return client._topology.description.server_descriptions().keys()
def assert_nodelist_change(self, expected_nodelist, client):
def assert_nodelist_change(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
"""Check if the client._topology eventually sees all nodes in the
expected_nodelist.
"""
@ -122,9 +122,9 @@ class TestSrvPolling(unittest.TestCase):
return True
return False
wait_until(predicate, "see expected nodelist", timeout=100 * WAIT_TIME)
wait_until(predicate, "see expected nodelist", timeout=timeout)
def assert_nodelist_nochange(self, expected_nodelist, client):
def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
"""Check if the client._topology ever deviates from seeing all nodes
in the expected_nodelist. Consistency is checked after sleeping for
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
@ -136,7 +136,7 @@ class TestSrvPolling(unittest.TestCase):
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
return False
wait_until(predicate, "Node list equals expected nodelist", timeout=100 * WAIT_TIME)
wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
nodelist = self.get_nodelist(client)
if set(expected_nodelist) != set(nodelist):
@ -330,6 +330,22 @@ class TestSrvPolling(unittest.TestCase):
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
self.assert_nodelist_change(response, client)
def test_srv_waits_to_poll(self):
modified = [("localhost.test.build.10gen.cc", 27019)]
def resolver_response():
return modified
with SrvPollingKnobs(
ttl_time=WAIT_TIME,
min_srv_rescan_interval=WAIT_TIME,
nodelist_callback=resolver_response,
):
client = MongoClient(self.CONNECTION_STRING)
self.assertRaises(
AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2
)
if __name__ == "__main__":
unittest.main()

View File

@ -172,6 +172,11 @@ elif OIDC_ENV == "azure":
"ENVIRONMENT": "azure",
"TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"],
}
elif OIDC_ENV == "gcp":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
"ENVIRONMENT": "gcp",
"TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"],
}
def interrupt_loop():