Compare commits

...

42 Commits

Author SHA1 Message Date
Noah Stapp
9a8e34c726
PYTHON-5366 - test_pool_reset waits until Pool.reset() increments gen… (#2797) 2026-05-18 10:29:44 -04:00
Noah Stapp
552b7bf47b
PYTHON-5631 - test_direct_client_maintains_pool_to_arbiter waits inst… (#2798) 2026-05-13 12:20:15 -04:00
Qi Deng
a50550535d
URL-encode client_id in Azure IMDS token request (#2787)
Co-authored-by: Qi Deng <qdeng@aurascape.ai>
2026-05-13 09:33:42 -04:00
Noah Stapp
0adf6df131
PYTHON-5708 - Unskip large encryption tests on mongocryptd (#2793) 2026-05-07 15:23:07 -04:00
Noah Stapp
f145c7db94
PYTHON-5756 - Fix BSON Binary type length bug (#2790) 2026-05-07 15:23:00 -04:00
Noah Stapp
b6bac45c7e
PYTHON-5032 - Use PyErr_GetRaisedException instead of deprecated PyEr… (#2795) 2026-05-07 14:52:19 -04:00
Noah Stapp
8dc7efade2
PYTHON-5821 - Fix ordering issue between event publish and logging for Pool monitoring tests (#2796) 2026-05-07 12:28:15 -04:00
Noah Stapp
f4219bdca2
PYTHON-5817 - Add "Project Structure and Asyncio Considerations" section to CONTRIBUTING.md (#2788)
Co-authored-by: Jib <Jibzade@gmail.com>
2026-05-06 13:28:36 -04:00
Noah Stapp
900d9c7910
PYTHON-5436 - Always include session on getMores if the initial curso… (#2794) 2026-05-06 13:10:13 -04:00
Noah Stapp
575d75f4d3
PYTHON-5813 - Skip QE prefixPreview and suffixPreview tests on server… (#2792) 2026-05-05 13:41:10 -04:00
Noah Stapp
c30eff1291
PYTHON-5811 - Change stream events are not emitted for timeseries as … (#2791) 2026-05-05 11:40:19 -04:00
Jeffrey 'Alex' Clark
e67931dff7
PYTHON-5776 Add documentation comments to justfile recipes (#2784) 2026-04-27 19:45:36 -04:00
mongodb-drivers-pr-bot[bot]
64edd22d73
[Spec Resync] 04-20-2026 (#2766)
Co-authored-by: Cloud User <ec2-user@ip-10-128-20-182.ec2.internal>
Co-authored-by: Jeffrey 'Alex' Clark <aclark@aclark.net>
2026-04-27 15:56:10 -04:00
Jeffrey 'Alex' Clark
b3f1c4befb
[Spec Resync] Remove stale spec patches for closed tickets (#2782) 2026-04-27 15:55:18 -04:00
Jeffrey 'Alex' Clark
ab44a21b46
PYTHON-5780 Increase code coverage for pyopenssl_context.py (#2773) 2026-04-24 09:04:02 -04:00
Jeffrey 'Alex' Clark
a13842f351
PYTHON-5778 Add 100% unit test coverage for event_loggers.py (#2769) 2026-04-21 12:36:48 -04:00
Jeffrey 'Alex' Clark
8363bf60ad
PYTHON-5774 Increase daemon.py coverage to 63% (#2759) 2026-04-20 16:52:36 -04:00
Jeffrey 'Alex' Clark
5406febcd9
Bump version to 4.18.0.dev0 (#2768) 2026-04-20 16:51:01 -04:00
Noah Stapp
3491c08ef6
PYTHON-5801 - Update changelog for 4.17 release (#2762) 2026-04-17 14:17:53 -04:00
Noah Stapp
912ef337f9
PYTHON-5798 - Overload retargeting prose tests do not ensure that sec… (#2760)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-16 13:32:50 -04:00
Noah Stapp
b4e2c03a92
PYTHON-5800 - Simple collation is included in index information (#2761) 2026-04-16 12:25:23 -04:00
Noah Stapp
f31ba09713
PYTHON-5797 - Add IWM and Overload Error links to changelog (#2757)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-15 14:42:29 -04:00
Noah Stapp
5da91837d4
PYTHON-5794 - Add prose tests to verify correct retry behavior when a… (#2755)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jib <Jibzade@gmail.com>
2026-04-15 14:18:34 -04:00
Copilot
35e51a50f3
Revert "PYTHON-5768 Add AGENTS.md w/copilot instructions" (#2744) (#2754)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: aclark4life <72164+aclark4life@users.noreply.github.com>
Co-authored-by: Jib <jib.adegunloye@mongodb.com>
2026-04-15 12:59:12 -04:00
Jeffrey 'Alex' Clark
f41dd5c08b
PYTHON-5772 Increase _gcp_helpers.py coverage (#2749)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 16:53:35 -04:00
Jeffrey 'Alex' Clark
49e7a052e2
PYTHON-5760 Increase _azure_helpers.py coverage (#2747) 2026-04-14 16:24:51 -04:00
Jeffrey 'Alex' Clark
a2b0cd85e3
PYTHON-5795 Fix absolute link to CONTRIBUTING.md in README.md (#2756) 2026-04-14 15:48:00 -04:00
Noah Stapp
e1751ff253
PYTHON-5668 - Merge backpressure branch into mainline (#2729)
Co-authored-by: Steven Silvester <steve.silvester@mongodb.com>
Co-authored-by: Shane Harvey <shnhrv@gmail.com>
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Kevin Albertson <kevin.albertson@mongodb.com>
Co-authored-by: Casey Clements <caseyclements@users.noreply.github.com>
Co-authored-by: Sergey Zelenov <mail@zelenov.su>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 12:25:29 -04:00
Noah Stapp
ee20ef52ec
PYTHON-5791 - test_list_database_names should not check ordering (#2751) 2026-04-13 14:01:14 -04:00
Jeffrey 'Alex' Clark
08b806fd87
PYTHON-5768 Add AGENTS.md w/copilot instructions (#2744)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-07 12:20:27 -04:00
Jib
db4db928d3
PYTHON-5401: Add AI Generated Contributions Policy (#2696)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-01 11:51:53 -04:00
dependabot[bot]
ee851ba974
Bump astral-sh/setup-uv from 7.3.0 to 7.6.0 in the actions group (#2740)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-31 11:50:25 -07:00
mongodb-drivers-pr-bot[bot]
ce416a0944
[Spec Resync] 03-30-2026 (#2741)
Co-authored-by: Cloud User <ec2-user@ip-10-128-20-15.ec2.internal>
Co-authored-by: Iris Ho <iris.ho@mongodb.com>
2026-03-31 11:41:46 -07:00
dependabot[bot]
daba50c797
Bump the actions group across 1 directory with 4 updates (#2736)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 14:56:12 -04:00
Jeffrey 'Alex' Clark
c3428789fb
PYTHON-5766 Add codecov badge to readme (#2737)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-23 10:55:50 -04:00
Jeffrey 'Alex' Clark
ec9d95413c
PYTHON-5757 Deprecate Python 2 methods in SON (#2732) 2026-03-18 17:46:23 -04:00
Jeffrey 'Alex' Clark
13085ff679
PYTHON-5758 Remove unused validation functions (#2733) 2026-03-18 13:19:18 -04:00
Jeffrey 'Alex' Clark
80c3ff2aee
PYTHON-5753 Add just recipes for running coverage tests locally (#2727) 2026-03-12 12:42:15 -04:00
Jeffrey 'Alex' Clark
3d89d9faca
PYTHON-5754 Fix USE_ACTIVE_VENV support (#2728) 2026-03-11 14:09:11 -04:00
Shane Harvey
b6cc22ffdd
PYTHON-5748 Remove unused SpecRunner class (#2725) 2026-03-09 12:37:32 -07:00
Shane Harvey
f303125cee
PYTHON-5114 Test suite reduce killAllSessions calls (#2721) 2026-03-09 11:53:40 -07:00
Iris
38da6c3f9a
PYTHON-5747 Add jira link to spec resync PR (#2723) 2026-03-09 12:24:59 -04:00
96 changed files with 14482 additions and 1939 deletions

View File

@ -615,6 +615,7 @@ buildvariants:
- name: test-win64
tasks:
- name: .test-standard !.pypy
- name: .test-no-orchestration !.pypy
display_name: "* Test Win64"
run_on:
- windows-2022-latest-small

View File

@ -94,6 +94,9 @@ do
change-streams|change_streams)
cpjson change-streams/tests/ change_streams/
;;
client-backpressure|client_backpressure)
cpjson client-backpressure/tests client-backpressure
;;
client-side-encryption|csfle|fle)
cpjson client-side-encryption/tests/ client-side-encryption/spec
cpjson client-side-encryption/corpus/ client-side-encryption/corpus

View File

@ -97,6 +97,8 @@ def create_standard_nonlinux_variants() -> list[BuildVariant]:
tasks = [
f".test-standard !.pypy .server-{version}" for version in get_versions_from("6.0")
]
if host_name == "win64":
tasks.append(".test-no-orchestration !.pypy")
host = HOSTS[host_name]
tags = ["standard-non-linux"]
expansions = dict()

View File

@ -7,6 +7,8 @@ import subprocess
from argparse import Namespace
from subprocess import CalledProcessError
JIRA_FILTER = "https://jira.mongodb.org/issues/?jql=labels%20%3D%20automated-sync%20AND%20status%20!%3D%20Closed"
def resync_specs(directory: pathlib.Path, errored: dict[str, str]) -> None:
"""Actually sync the specs"""
@ -117,6 +119,7 @@ def write_summary(errored: dict[str, str], new: list[str], filename: str | None)
pr_body += "\n -".join(new)
pr_body += "\n"
if pr_body != "":
pr_body = f"Jira tickets: {JIRA_FILTER}\n\n" + pr_body
if filename is None:
print(f"\n{pr_body}")
else:

View File

@ -153,6 +153,10 @@ def handle_test_env() -> None:
# Start compiling the args we'll pass to uv.
UV_ARGS = ["--extra test --no-group dev"]
# If USE_ACTIVE_VENV is set, add --active to UV_ARGS so run-tests.sh uses the active venv.
if is_set("USE_ACTIVE_VENV"):
UV_ARGS.append("--active")
test_title = test_name
if sub_test_name:
test_title += f" {sub_test_name}"

View File

@ -1,64 +0,0 @@
diff --git a/test/load_balancer/cursors.json b/test/load_balancer/cursors.json
index 43e4fbb4f..4e2a55fd4 100644
--- a/test/load_balancer/cursors.json
+++ b/test/load_balancer/cursors.json
@@ -376,7 +376,7 @@
]
},
{
+ "description": "pinned connections are not returned after an network error during getMore",
- "description": "pinned connections are returned after an network error during getMore",
"operations": [
{
"name": "failPoint",
@@ -440,7 +440,7 @@
"object": "testRunner",
"arguments": {
"client": "client0",
+ "connections": 1
- "connections": 0
}
},
{
@@ -659,7 +659,7 @@
]
},
{
+ "description": "pinned connections are not returned to the pool after a non-network error on getMore",
- "description": "pinned connections are returned to the pool after a non-network error on getMore",
"operations": [
{
"name": "failPoint",
@@ -715,7 +715,7 @@
"object": "testRunner",
"arguments": {
"client": "client0",
+ "connections": 1
- "connections": 0
}
},
{
diff --git a/test/load_balancer/sdam-error-handling.json b/test/load_balancer/sdam-error-handling.json
index 63aabc04d..462fa0aac 100644
--- a/test/load_balancer/sdam-error-handling.json
+++ b/test/load_balancer/sdam-error-handling.json
@@ -366,6 +366,9 @@
{
"connectionCreatedEvent": {}
},
+ {
+ "poolClearedEvent": {}
+ },
{
"connectionClosedEvent": {
"reason": "error"
@@ -378,9 +375,6 @@
"connectionCheckOutFailedEvent": {
"reason": "connectionError"
}
- },
- {
- "poolClearedEvent": {}
}
]
}

View File

@ -1,14 +0,0 @@
diff --git a/test/discovery_and_monitoring/unified/serverMonitoringMode.json b/test/discovery_and_monitoring/unified/serverMonitoringMode.json
index e44fad1b..4b492f7d 100644
--- a/test/discovery_and_monitoring/unified/serverMonitoringMode.json
+++ b/test/discovery_and_monitoring/unified/serverMonitoringMode.json
@@ -5,7 +5,8 @@
{
"topologies": [
"single",
- "sharded"
+ "sharded",
+ "sharded-replicaset"
],
"serverless": "forbid"
}

View File

@ -1,61 +0,0 @@
diff --git a/test/server_selection_logging/replica-set.json b/test/server_selection_logging/replica-set.json
index 830b1ea51..5eba784bf 100644
--- a/test/server_selection_logging/replica-set.json
+++ b/test/server_selection_logging/replica-set.json
@@ -184,7 +184,7 @@
}
},
{
- "level": "debug",
+ "level": "info",
"component": "serverSelection",
"data": {
"message": "Waiting for suitable server to become available",
diff --git a/test/server_selection_logging/standalone.json b/test/server_selection_logging/standalone.json
index 830b1ea51..5eba784bf 100644
--- a/test/server_selection_logging/standalone.json
+++ b/test/server_selection_logging/standalone.json
@@ -191,7 +191,7 @@
}
},
{
- "level": "debug",
+ "level": "info",
"component": "serverSelection",
"data": {
"message": "Waiting for suitable server to become available",
diff --git a/test/server_selection_logging/sharded.json b/test/server_selection_logging/sharded.json
index 830b1ea51..5eba784bf 100644
--- a/test/server_selection_logging/sharded.json
+++ b/test/server_selection_logging/sharded.json
@@ -193,7 +193,7 @@
}
},
{
- "level": "debug",
+ "level": "info",
"component": "serverSelection",
"data": {
"message": "Waiting for suitable server to become available",
diff --git a/test/server_selection_logging/sharded.json b/test/server_selection_logging/operation-id.json
index 830b1ea51..5eba784bf 100644
--- a/test/server_selection_logging/operation-id.json
+++ b/test/server_selection_logging/operation-id.json
@@ -197,7 +197,7 @@
}
},
{
- "level": "debug",
+ "level": "info",
"component": "serverSelection",
"data": {
"message": "Waiting for suitable server to become available",
@@ -383,7 +383,7 @@
}
},
{
- "level": "debug",
+ "level": "info",
"component": "serverSelection",
"data": {
"message": "Waiting for suitable server to become available",

View File

@ -1,31 +0,0 @@
diff --git a/test/discovery_and_monitoring/errors/error_handling_handshake.json b/test/discovery_and_monitoring/errors/error_handling_handshake.json
index 56ca7d113..bf83f46f6 100644
--- a/test/discovery_and_monitoring/errors/error_handling_handshake.json
+++ b/test/discovery_and_monitoring/errors/error_handling_handshake.json
@@ -97,14 +97,22 @@
"outcome": {
"servers": {
"a:27017": {
- "type": "Unknown",
- "topologyVersion": null,
+ "type": "RSPrimary",
+ "setName": "rs",
+ "topologyVersion": {
+ "processId": {
+ "$oid": "000000000000000000000001"
+ },
+ "counter": {
+ "$numberLong": "1"
+ }
+ },
"pool": {
- "generation": 1
+ "generation": 0
}
}
},
- "topologyType": "ReplicaSetNoPrimary",
+ "topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs"
}

View File

@ -0,0 +1,460 @@
diff --git a/test/client-side-encryption/spec/unified/accessToken-azure.json b/test/client-side-encryption/spec/unified/accessToken-azure.json
new file mode 100644
index 00000000..510d8795
--- /dev/null
+++ b/test/client-side-encryption/spec/unified/accessToken-azure.json
@@ -0,0 +1,186 @@
+{
+ "description": "accessToken-azure",
+ "schemaVersion": "1.28",
+ "runOnRequirements": [
+ {
+ "minServerVersion": "4.1.10",
+ "csfle": {
+ "minLibmongocryptVersion": "1.6.0"
+ }
+ }
+ ],
+ "createEntities": [
+ {
+ "client": {
+ "id": "client",
+ "autoEncryptOpts": {
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "azure": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ },
+ {
+ "database": {
+ "id": "db",
+ "client": "client",
+ "databaseName": "db"
+ }
+ },
+ {
+ "collection": {
+ "id": "coll",
+ "database": "db",
+ "collectionName": "coll"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "azure": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ }
+ ],
+ "initialData": [
+ {
+ "databaseName": "db",
+ "collectionName": "coll",
+ "documents": [],
+ "createOptions": {
+ "validator": {
+ "$jsonSchema": {
+ "properties": {
+ "secret": {
+ "encrypt": {
+ "keyId": [
+ {
+ "$binary": {
+ "base64": "AZURE+AAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ }
+ ],
+ "bsonType": "string",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ }
+ },
+ "bsonType": "object"
+ }
+ }
+ }
+ },
+ {
+ "databaseName": "keyvault",
+ "collectionName": "datakeys",
+ "documents": [
+ {
+ "_id": {
+ "$binary": {
+ "base64": "AZURE+AAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ },
+ "keyAltNames": [
+ "my-key"
+ ],
+ "keyMaterial": {
+ "$binary": {
+ "base64": "n+HWZ0ZSVOYA3cvQgP7inN4JSXfOH85IngmeQxRpQHjCCcqT3IFqEWNlrsVHiz3AELimHhX4HKqOLWMUeSIT6emUDDoQX9BAv8DR1+E1w4nGs/NyEneac78EYFkK3JysrFDOgl2ypCCTKAypkn9CkAx1if4cfgQE93LW4kczcyHdGiH36CIxrCDGv1UzAvERN5Qa47DVwsM6a+hWsF2AAAJVnF0wYLLJU07TuRHdMrrphPWXZsFgyV+lRqJ7DDpReKNO8nMPLV/mHqHBHGPGQiRdb9NoJo8CvokGz4+KE8oLwzKf6V24dtwZmRkrsDV4iOhvROAzz+Euo1ypSkL3mw==",
+ "subType": "00"
+ }
+ },
+ "creationDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "updateDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "status": {
+ "$numberInt": "0"
+ },
+ "masterKey": {
+ "provider": "azure",
+ "keyVaultEndpoint": "key-vault-csfle.vault.azure.net",
+ "keyName": "key-name-csfle"
+ }
+ }
+ ]
+ }
+ ],
+ "tests": [
+ {
+ "description": "Auto encrypt using access token Azure credentials",
+ "operations": [
+ {
+ "name": "insertOne",
+ "arguments": {
+ "document": {
+ "_id": 1,
+ "secret": "string0"
+ }
+ },
+ "object": "coll"
+ }
+ ],
+ "outcome": [
+ {
+ "documents": [
+ {
+ "_id": 1,
+ "secret": {
+ "$binary": {
+ "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==",
+ "subType": "06"
+ }
+ }
+ }
+ ],
+ "collectionName": "coll",
+ "databaseName": "db"
+ }
+ ]
+ },
+ {
+ "description": "Explicit encrypt using access token Azure credentials",
+ "operations": [
+ {
+ "name": "encrypt",
+ "object": "clientEncryption",
+ "arguments": {
+ "value": "string0",
+ "opts": {
+ "keyAltName": "my-key",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ },
+ "expectResult": {
+ "$binary": {
+ "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==",
+ "subType": "06"
+ }
+ }
+ }
+ ]
+ }
+ ]
+}
diff --git a/test/client-side-encryption/spec/unified/accessToken-gcp.json b/test/client-side-encryption/spec/unified/accessToken-gcp.json
new file mode 100644
index 00000000..f5cf8914
--- /dev/null
+++ b/test/client-side-encryption/spec/unified/accessToken-gcp.json
@@ -0,0 +1,188 @@
+{
+ "description": "accessToken-gcp",
+ "schemaVersion": "1.28",
+ "runOnRequirements": [
+ {
+ "minServerVersion": "4.1.10",
+ "csfle": {
+ "minLibmongocryptVersion": "1.6.0"
+ }
+ }
+ ],
+ "createEntities": [
+ {
+ "client": {
+ "id": "client",
+ "autoEncryptOpts": {
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "gcp": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ },
+ {
+ "database": {
+ "id": "db",
+ "client": "client",
+ "databaseName": "db"
+ }
+ },
+ {
+ "collection": {
+ "id": "coll",
+ "database": "db",
+ "collectionName": "coll"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "gcp": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ }
+ ],
+ "initialData": [
+ {
+ "databaseName": "db",
+ "collectionName": "coll",
+ "documents": [],
+ "createOptions": {
+ "validator": {
+ "$jsonSchema": {
+ "properties": {
+ "secret": {
+ "encrypt": {
+ "keyId": [
+ {
+ "$binary": {
+ "base64": "GCP+AAAAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ }
+ ],
+ "bsonType": "string",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ }
+ },
+ "bsonType": "object"
+ }
+ }
+ }
+ },
+ {
+ "databaseName": "keyvault",
+ "collectionName": "datakeys",
+ "documents": [
+ {
+ "_id": {
+ "$binary": {
+ "base64": "GCP+AAAAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ },
+ "keyAltNames": [
+ "my-key"
+ ],
+ "keyMaterial": {
+ "$binary": {
+ "base64": "CiQAIgLj0WyktnB4dfYHo5SLZ41K4ASQrjJUaSzl5vvVH0G12G0SiQEAjlV8XPlbnHDEDFbdTO4QIe8ER2/172U1ouLazG0ysDtFFIlSvWX5ZnZUrRMmp/R2aJkzLXEt/zf8Mn4Lfm+itnjgo5R9K4pmPNvvPKNZX5C16lrPT+aA+rd+zXFSmlMg3i5jnxvTdLHhg3G7Q/Uv1ZIJskKt95bzLoe0tUVzRWMYXLIEcohnQg==",
+ "subType": "00"
+ }
+ },
+ "creationDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "updateDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "status": {
+ "$numberInt": "0"
+ },
+ "masterKey": {
+ "provider": "gcp",
+ "projectId": "devprod-drivers",
+ "location": "global",
+ "keyRing": "key-ring-csfle",
+ "keyName": "key-name-csfle"
+ }
+ }
+ ]
+ }
+ ],
+ "tests": [
+ {
+ "description": "Auto encrypt using access token GCP credentials",
+ "operations": [
+ {
+ "name": "insertOne",
+ "arguments": {
+ "document": {
+ "_id": 1,
+ "secret": "string0"
+ }
+ },
+ "object": "coll"
+ }
+ ],
+ "outcome": [
+ {
+ "documents": [
+ {
+ "_id": 1,
+ "secret": {
+ "$binary": {
+ "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==",
+ "subType": "06"
+ }
+ }
+ }
+ ],
+ "collectionName": "coll",
+ "databaseName": "db"
+ }
+ ]
+ },
+ {
+ "description": "Explicit encrypt using access token GCP credentials",
+ "operations": [
+ {
+ "name": "encrypt",
+ "object": "clientEncryption",
+ "arguments": {
+ "value": "string0",
+ "opts": {
+ "keyAltName": "my-key",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ },
+ "expectResult": {
+ "$binary": {
+ "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==",
+ "subType": "06"
+ }
+ }
+ }
+ ]
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json
new file mode 100644
index 00000000..8fe5c150
--- /dev/null
+++ b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json
@@ -0,0 +1,31 @@
+{
+ "description": "clientEncryptionOpts-kmsProviders-azure-accessToken-type",
+ "schemaVersion": "1.28",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption0",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client0",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "azure": {
+ "accessToken": 0
+ }
+ }
+ }
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "",
+ "operations": []
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json
new file mode 100644
index 00000000..2284e26c
--- /dev/null
+++ b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json
@@ -0,0 +1,31 @@
+{
+ "description": "clientEncryptionOpts-kmsProviders-gcp-accessToken-type",
+ "schemaVersion": "1.28",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption0",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client0",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "gcp": {
+ "accessToken": 0
+ }
+ }
+ }
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "",
+ "operations": []
+ }
+ ]
+}

View File

@ -6,8 +6,8 @@ If you are an external contributor and there is no JIRA ticket associated with y
for the PR title. A MongoDB employee will create a JIRA ticket and edit the name and links as appropriate.
Note on AI Contributions:
We do not accept pull requests that are primarily or substantially generated by AI tools (ChatGPT, Copilot, etc.).
All contributions must be written and understood by human contributors.
We only accept pull requests that are authored and submitted by human contributors who fully understand the changes they are proposing.
All contributions must be written and understood by human contributors. Please read about our policy in our contributing guide.
-->
[JIRA TICKET]

View File

@ -61,7 +61,7 @@ jobs:
- name: Set up QEMU
if: runner.os == 'Linux'
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
with:
# setup-qemu-action by default uses `tonistiigi/binfmt:latest` image,
# which is out of date. This causes seg faults during build.
@ -92,7 +92,7 @@ jobs:
# Free-threading builds:
ls wheelhouse/*cp314t*.whl
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: wheel-${{ matrix.buildplat[1] }}
path: ./wheelhouse/*.whl
@ -125,7 +125,7 @@ jobs:
cd ..
python -c "from pymongo import has_c; assert has_c()"
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: "sdist"
path: ./dist/*.tar.gz
@ -136,13 +136,13 @@ jobs:
name: Download Wheels
steps:
- name: Download all workflow run artifacts
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
- name: Flatten directory
working-directory: .
run: |
find . -mindepth 2 -type f -exec mv {} . \;
find . -type d -empty -delete
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: all-dist-${{ github.run_id }}
path: "./*"

View File

@ -75,7 +75,7 @@ jobs:
id-token: write
steps:
- name: Download all the dists
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
name: all-dist-${{ github.run_id }}
path: dist/

View File

@ -67,7 +67,7 @@ jobs:
run: rm -rf .venv .venv-sbom sbom-requirements.txt
- name: Upload SBOM artifact
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: sbom
path: sbom.json

View File

@ -26,7 +26,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -68,7 +68,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -90,7 +90,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -118,7 +118,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -143,7 +143,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -162,7 +162,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -184,7 +184,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "${{matrix.python}}"
@ -205,7 +205,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -245,7 +245,7 @@ jobs:
run: |
pip install build
python -m build --sdist
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: "sdist"
path: dist/*.tar.gz
@ -257,7 +257,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Download sdist
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
path: sdist/
- name: Unpack SDist
@ -295,7 +295,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
python-version: "3.9"
- id: setup-mongodb

View File

@ -18,4 +18,4 @@ jobs:
with:
persist-credentials: false
- name: Run zizmor 🌈
uses: zizmorcore/zizmor-action@0dce2577a4760a2749d8cfb7a84b7d5585ebcb7d # v0.5.0
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2

1
.gitignore vendored
View File

@ -43,3 +43,4 @@ test/lambda/*.json
xunit-results/
coverage.xml
server.log
.coverage

View File

@ -85,49 +85,53 @@ likelihood for getting review sooner shoots up.
- `versionadded:: 3.11`
- `versionchanged:: 3.5`
**Pull Request Template Breakdown**
### AI-Generated Contributions Policy
- **Github PR Title**
#### Our Stance
- The PR Title format should always be
`[JIRA-ID] : Jira Title or Blurb Summary`.
We only accept pull requests that are authored and submitted by human contributors who fully understand the changes they are proposing. Pull requests that are not clearly owned and understood by a human contributor may be closed. **All contributions must be submitted, reviewed, and understood by human contributors.**
- **JIRA LINK**
##### Why This Policy Exists
- Convenient link to the associated JIRA ticket.
At MongoDB, we understand the power and prevalence of AI tools in software development. With that being said, many MongoDB libraries are foundational tools used in production systems worldwide. The nature of these libraries requires:
- **Summary**
- **Deep domain expertise**: MongoDB's wire protocol, BSON specification, connection pooling, authentication mechanisms, and concurrency patterns require an understanding that AI alone cannot substantiate.
- Small blurb on why this is needed. The JIRA task should have
the more in-depth description, but this should still, at a
high level, give anyone looking an understanding of why the
PR has been checked in.
- **Long-term maintainability**: Contributors need to be able to explain *why* code is written a certain way, explain design decisions, and be available to iterate on their contributions.
- **Changes in this PR**
- **Security responsibility**: Authentication, credential handling, and TLS implementation cannot be left to probabilistic code generation.
- The explicit code changes that this PR is introducing. This
should be more specific than just the task name. (Unless the
task name is very clear).
##### What This Means for Contributors
- **Test Plan**
**Required:**
- Everything needs a test description. Describe what you did
to validate your changes actually worked; if you did
nothing, then document you did not test it. Aim to make
these steps reproducible by other engineers, specifically
with your primary reviewer in mind.
- Full understanding of every line of code you submit
- Ability to explain and defend your implementation choices
- Willingness to iterate and maintain your contributions
- **Screenshots**
**Encouraged:**
- Any images that provide more context to the PR. Usually,
these just coincide with the test plan.
- Using AI assistants as learning tools to understand concepts
- IDE autocomplete features that suggest standard patterns
- AI help for brainstorming approaches (but write the code yourself)
- Writing code using AI tools, reviewing each line and revising code as necessary.
- **Callouts or follow-up items**
**Not allowed:**
- This is a good place for identifying "to-dos" that you've
placed in the code (Must have an accompanying JIRA Ticket).
- Potential bugs that you are unsure how to test in the code.
- Opinions you want to receive about your code.
- Submitting PRs generated solely by AI tools
- Copy-pasting AI-generated code without full understanding
##### Disclosure
If you used AI assistance in any way during your contribution, please disclose what the AI assistant was used for in your PR description. We would love to know what tools developers have found useful in iterating in their day to day.
##### Questions?
If you're unsure whether your contribution complies with this policy, please ask for guidance within the scope of the PR and clarify any uncertainty. We're happy to guide contributors toward successful contributions.
---
*This policy helps us maintain the reliability, security, and trustworthiness that production applications depend on. Thank you for understanding and for contributing thoughtfully to PyMongo.*
## Running Linters
@ -205,6 +209,7 @@ the pages will re-render and the browser will automatically refresh.
and the `<class_name>` to test a full module. For example:
`just test test/test_change_stream.py::TestUnifiedChangeStreamsErrors::test_change_stream_errors_on_ElectionInProgress`.
- Use the `-k` argument to select tests by pattern.
- Run `just test-coverage` to run tests with coverage and display a report. After running tests with coverage, use `just coverage-html` to generate an HTML report in `htmlcov/index.html`.
## Running tests that require secrets, services, or other configuration
@ -500,13 +505,20 @@ python3 ./.evergreen/scripts/resync-all-specs.py
Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process).
## Asyncio considerations
## Project Structure and Asyncio Considerations
PyMongo adds asyncio capability by modifying the source files in `*/asynchronous` to `*/synchronous` using
[unasync](https://github.com/python-trio/unasync/) and some custom transforms.
This section describes the layout of the `pymongo/` package.
Where possible, edit the code in `*/asynchronous/*.py` and not the synchronous files.
You can run `pre-commit run --all-files synchro` before running tests if you are testing synchronous code.
Within `pymongo/`, the code is further divided into the `pymongo/asynchronous` and `pymongo/synchronous` subdirectories.
Files in `pymongo/synchronous` are generated from `pymongo/asynchronous` using the `synchro` pre-commit hook, which uses [unasync](https://github.com/python-trio/unasync/) and some custom transforms.
As a result, **all modifications** within `pymongo` must be made in either the top-level `pymongo` directory when they have to exhibit differing behavior between sync and async contexts or the `pymongo/asynchronous` directory, not `pymongo/synchronous`.
Any changes made directly to files in the `pymongo/synchronous` directory will be overwritten by the `synchro` hook when it is run, which happens automatically on commit.
Some top-level files (e.g. `pymongo/collection.py`) are re-export files for existing import compatibility and should not be modified directly.
The other top-level files (e.g. `pymongo/network_layer.py`, `pymongo/pool_shared.py`) contain either shared code used in both the asynchronous and synchronous APIs, or code that is very different between the two APIs and therefore cannot be generated from the async version using `synchro`.
Run `pre-commit run --all-files synchro` before running tests to generate the latest version of the synchronous code.
To prevent the `synchro` hook from accidentally overwriting code, it first checks to see whether a sync version
of a file is changing and not its async counterpart, and will fail.

View File

@ -4,6 +4,7 @@
[![Python Versions](https://img.shields.io/pypi/pyversions/pymongo)](https://pypi.org/project/pymongo)
[![Monthly Downloads](https://static.pepy.tech/badge/pymongo/month)](https://pepy.tech/project/pymongo)
[![API Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/api?badge=stable)
[![codecov](https://codecov.io/gh/mongodb/mongo-python-driver/graph/badge.svg?branch=master)](https://codecov.io/gh/mongodb/mongo-python-driver)
## About
@ -215,4 +216,4 @@ pip install -e ".[test]"
pytest
```
For more advanced testing scenarios, see the [contributing guide](./CONTRIBUTING.md#running-tests-locally).
For more advanced testing scenarios, see the [contributing guide](https://github.com/mongodb/mongo-python-driver/blob/master/CONTRIBUTING.md#running-tests-locally).

View File

@ -109,6 +109,7 @@ struct module_state {
#define DATETIME_CLAMP 2
#define DATETIME_MS 3
#define DATETIME_AUTO 4
#define PYTHON_3_12 0x030C0000
/* Converts integer to its string representation in decimal notation. */
extern int cbson_long_long_to_str(long long num, char* str, size_t size) {
@ -249,6 +250,67 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
*/
static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw);
#if PY_VERSION_HEX >= PYTHON_3_12
/* Transfer traceback from old_exc to new_exc.
* Steals reference to old_exc. */
static PyObject* _transfer_traceback(PyObject *old_exc, PyObject *new_exc) {
PyObject *tb = PyException_GetTraceback(old_exc);
if (tb) {
PyException_SetTraceback(new_exc, tb);
Py_DECREF(tb);
}
Py_DECREF(old_exc);
return new_exc;
}
#endif
/* Rewrap the current exception as InvalidBSON(str(e)) if it is not already an InvalidBSON error. */
static void _rewrap_as_invalid_bson(void) {
#if PY_VERSION_HEX >= PYTHON_3_12
PyObject *exc = PyErr_GetRaisedException();
if (exc && PyErr_GivenExceptionMatches(exc, PyExc_Exception)) {
PyObject *InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
if (!PyErr_GivenExceptionMatches(exc, InvalidBSON)) {
PyObject *err_msg = PyObject_Str(exc);
if (err_msg) {
PyObject *new_exc = PyObject_CallOneArg(InvalidBSON, err_msg);
if (new_exc) {
exc = _transfer_traceback(exc, new_exc);
}
}
Py_XDECREF(err_msg);
}
Py_DECREF(InvalidBSON);
}
}
/* Steals reference to exc. */
PyErr_SetRaisedException(exc);
#else
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyObject *InvalidBSON = NULL;
PyErr_Fetch(&etype, &evalue, &etrace);
if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) {
InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) {
Py_DECREF(etype);
etype = InvalidBSON;
if (evalue) {
PyObject *msg = PyObject_Str(evalue);
Py_DECREF(evalue);
evalue = msg;
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
} else {
Py_DECREF(InvalidBSON);
}
}
}
PyErr_Restore(etype, evalue, etrace);
#endif
}
/* Date stuff */
static PyObject* datetime_from_millis(long long millis) {
/* To encode a datetime instance like datetime(9999, 12, 31, 23, 59, 59, 999999)
@ -294,34 +356,57 @@ static PyObject* datetime_from_millis(long long millis) {
timeinfo.tm_sec,
microseconds);
if(!datetime) {
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
#if PY_VERSION_HEX >= PYTHON_3_12
PyObject *exc = PyErr_GetRaisedException();
/*
* Calling _error clears the error state, so fetch it first.
*/
PyErr_Fetch(&etype, &evalue, &etrace);
/* Only add addition error message on ValueError exceptions. */
if (PyErr_GivenExceptionMatches(etype, PyExc_ValueError)) {
if (evalue) {
PyObject* err_msg = PyObject_Str(evalue);
/* Only add additional error message on ValueError exceptions. */
if (exc && PyErr_GivenExceptionMatches(exc, PyExc_ValueError)) {
PyObject* err_msg = PyObject_Str(exc);
if (err_msg) {
PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes");
if (appendage) {
PyObject* msg = PyUnicode_Concat(err_msg, appendage);
if (msg) {
Py_DECREF(evalue);
evalue = msg;
PyObject* new_exc = PyObject_CallOneArg(PyExc_ValueError, msg);
if (new_exc) {
exc = _transfer_traceback(exc, new_exc);
}
Py_DECREF(msg);
}
}
Py_XDECREF(appendage);
}
Py_XDECREF(err_msg);
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
}
/* Steals references to args. */
PyErr_Restore(etype, evalue, etrace);
/* Steals reference to exc. */
PyErr_SetRaisedException(exc);
#else
/* Calling _error clears the error state, so fetch it first.*/
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyErr_Fetch(&etype, &evalue, &etrace);
/* Only add additional error message on ValueError exceptions. */
if (PyErr_GivenExceptionMatches(etype, PyExc_ValueError)) {
if (evalue) {
PyObject* err_msg = PyObject_Str(evalue);
if (err_msg) {
PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes");
if (appendage) {
PyObject* msg = PyUnicode_Concat(err_msg, appendage);
if (msg) {
Py_DECREF(evalue);
evalue = msg;
}
}
Py_XDECREF(appendage);
}
Py_XDECREF(err_msg);
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
}
/* Steals references to args. */
PyErr_Restore(etype, evalue, etrace);
#endif
}
return datetime;
}
@ -1681,6 +1766,46 @@ fail:
/* Update Invalid Document error to include doc as a property.
*/
void handle_invalid_doc_error(PyObject* dict) {
#if PY_VERSION_HEX >= PYTHON_3_12
PyObject *exc = PyErr_GetRaisedException();
PyObject *msg = NULL, *new_msg = NULL;
PyObject *InvalidDocument = NULL;
if (exc == NULL) {
return;
}
InvalidDocument = _error("InvalidDocument");
if (InvalidDocument == NULL) {
goto cleanup;
}
if (PyErr_GivenExceptionMatches(exc, InvalidDocument)) {
msg = PyObject_Str(exc);
if (msg) {
const char *msg_utf8 = PyUnicode_AsUTF8(msg);
if (msg_utf8 == NULL) {
goto cleanup;
}
new_msg = PyUnicode_FromFormat("Invalid document: %s", msg_utf8);
if (new_msg == NULL) {
goto cleanup;
}
/* Add doc to the error instance as a property. */
PyObject* exc_args[2] = {new_msg, dict};
PyObject* new_exc = PyObject_Vectorcall(InvalidDocument, exc_args, 2, NULL);
if (new_exc) {
exc = _transfer_traceback(exc, new_exc);
}
}
}
cleanup:
/* Steals reference to exc. */
PyErr_SetRaisedException(exc);
Py_XDECREF(msg);
Py_XDECREF(InvalidDocument);
Py_XDECREF(new_msg);
#else
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyObject *msg = NULL, *new_msg = NULL, *new_evalue = NULL;
PyErr_Fetch(&etype, &evalue, &etrace);
@ -1723,6 +1848,7 @@ cleanup:
Py_XDECREF(InvalidDocument);
Py_XDECREF(new_evalue);
Py_XDECREF(new_msg);
#endif
}
@ -2155,7 +2281,7 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
memcpy(&length, buffer + *position, 4);
length = BSON_UINT32_FROM_LE(length);
if (max < length) {
if (max - 5 < length) { // Account for 5-byte header. max >= 5 guaranteed above
goto invalid;
}
@ -2654,42 +2780,7 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
* Wrap any non-InvalidBSON errors in InvalidBSON.
*/
if (PyErr_Occurred()) {
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyObject *InvalidBSON = NULL;
/*
* Calling _error clears the error state, so fetch it first.
*/
PyErr_Fetch(&etype, &evalue, &etrace);
/* Dont reraise anything but PyExc_Exceptions as InvalidBSON. */
if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) {
InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) {
/*
* Raise InvalidBSON(str(e)).
*/
Py_DECREF(etype);
etype = InvalidBSON;
if (evalue) {
PyObject *msg = PyObject_Str(evalue);
Py_DECREF(evalue);
evalue = msg;
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
} else {
/*
* The current exception matches InvalidBSON, so we don't
* need this reference after all.
*/
Py_DECREF(InvalidBSON);
}
}
}
/* Steals references to args. */
PyErr_Restore(etype, evalue, etrace);
_rewrap_as_invalid_bson();
} else {
PyObject *InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
@ -2727,25 +2818,7 @@ static int _element_to_dict(PyObject* self, const char* string,
if (!*name) {
/* If NULL is returned then wrap the UnicodeDecodeError
in an InvalidBSON error */
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyObject *InvalidBSON = NULL;
PyErr_Fetch(&etype, &evalue, &etrace);
if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) {
InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
Py_DECREF(etype);
etype = InvalidBSON;
if (evalue) {
PyObject *msg = PyObject_Str(evalue);
Py_DECREF(evalue);
evalue = msg;
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
}
}
PyErr_Restore(etype, evalue, etrace);
_rewrap_as_invalid_bson();
return -1;
}
position += (unsigned)name_length + 1;

View File

@ -22,6 +22,7 @@ from __future__ import annotations
import copy
import re
import warnings
from collections.abc import Mapping as _Mapping
from typing import (
Any,
@ -99,13 +100,28 @@ class SON(Dict[_Key, _Value]):
yield from self.__keys
def has_key(self, key: _Key) -> bool:
warnings.warn(
"SON.has_key() is deprecated, use the in operator instead",
DeprecationWarning,
stacklevel=2,
)
return key in self.__keys
def iterkeys(self) -> Iterator[_Key]:
warnings.warn(
"SON.iterkeys() is deprecated, use the keys() method instead",
DeprecationWarning,
stacklevel=2,
)
return self.__iter__()
# fourth level uses definitions from lower levels
def itervalues(self) -> Iterator[_Value]:
warnings.warn(
"SON.itervalues() is deprecated, use the values() method instead",
DeprecationWarning,
stacklevel=2,
)
for _, v in self.items():
yield v

View File

@ -1,14 +1,21 @@
Changelog
=========
Changes in Version 4.17.0 (2026/XX/XX)
Changes in Version 4.17.0 (2026/04/20)
--------------------------------------
PyMongo 4.17 brings a number of changes including:
- ``has_key``, ``iterkeys`` and ``itervalues`` in :class:`bson.son.SON` have
been deprecated and will be removed in PyMongo 5.0. These methods were
deprecated in favor of the standard dictionary containment operator ``in``
and the ``keys()`` and ``values()`` methods, respectively.
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information.
See the `Transactions docs <https://www.mongodb.com/docs/languages/python/pymongo-driver/current/crud/transactions/#methods>`_ for examples and more information.
- Added support for MongoDB's Intelligent Workload Management (IWM) and ingress connection rate limiting features.
The driver now gracefully handles write-blocking scenarios and optimizes connection establishment during high-load conditions to maintain application availability.
See the `IWM <https://www.mongodb.com/docs/atlas/intelligent-workload-management>`_ or `Overload Errors <https://www.mongodb.com/docs/atlas/overload-errors/?interface=driver&language=python>`_ docs for more information.
Changes in Version 4.16.0 (2026/01/07)
--------------------------------------

View File

@ -16,62 +16,78 @@ default:
resync:
@uv sync --quiet
# Set up the development environment
install:
bash .evergreen/scripts/setup-dev-env.sh
# Build the HTML documentation
[group('docs')]
docs: && resync
{{docs_run}} sphinx-build -W -b html doc {{doc_build}}/html
# Serve the docs locally with live-reload
[group('docs')]
docs-serve: && resync
{{docs_run}} sphinx-autobuild -W -b html doc --watch ./pymongo --watch ./bson --watch ./gridfs {{doc_build}}/serve
# Check documentation hyperlinks for broken URLs
[group('docs')]
docs-linkcheck: && resync
{{docs_run}} sphinx-build -E -b linkcheck doc {{doc_build}}/linkcheck
# Run mypy and pyright
[group('typing')]
typing: && resync
just typing-mypy
just typing-pyright
# Run mypy against the library source and test suite
[group('typing')]
typing-mypy: && resync
{{typing_run}} python -m mypy {{mypy_args}} bson gridfs tools pymongo
{{typing_run}} python -m mypy {{mypy_args}} --config-file mypy_test.ini test
{{typing_run}} python -m mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py
# Run pyright against the typing test files
[group('typing')]
typing-pyright: && resync
{{typing_run}} python -m pyright test/test_typing.py test/test_typing_strict.py
{{typing_run}} python -m pyright -p strict_pyrightconfig.json test/test_typing_strict.py
# Run all pre-commit hooks across the repository
[group('lint')]
lint *args="": && resync
uvx pre-commit run --all-files {{args}}
# Run shellcheck, doc8, and slotscheck
[group('lint')]
lint-manual *args="": && resync
uvx pre-commit run --all-files --hook-stage manual {{args}}
# Run pytest (e.g. just test test/test_uri_parser.py)
[group('test')]
test *args="-v --durations=5 --maxfail=10": && resync
uv run --extra test python -m pytest {{args}}
#!/usr/bin/env bash
set -euo pipefail
uv run ${USE_ACTIVE_VENV:+--active} --extra test python -m pytest {{args}}
# Run the BSON test suite with numpy
[group('test')]
test-numpy *args="": && resync
just setup-tests numpy {{args}}
just run-tests test/test_bson.py
# Run tests via the Evergreen test runner script
[group('test')]
run-tests *args: && resync
bash ./.evergreen/run-tests.sh {{args}}
# Set up the test environment (auth, TLS, etc.)
[group('test')]
setup-tests *args="":
bash .evergreen/scripts/setup-tests.sh {{args}}
# Tear down resources created by setup-tests
[group('test')]
teardown-tests:
bash .evergreen/scripts/teardown-tests.sh
@ -80,6 +96,30 @@ teardown-tests:
integration-tests:
bash integration_tests/run.sh
# Run the full test suite with coverage
[group('test')]
test-coverage *args="":
just setup-tests --cov
just run-tests {{args}}
# Print the coverage summary to the terminal
[group('coverage')]
coverage-report:
uv tool run --with "coverage[toml]" coverage report
# Generate an HTML coverage report in htmlcov/
[group('coverage')]
coverage-html:
uv tool run --with "coverage[toml]" coverage html
@echo "Coverage report generated in htmlcov/index.html"
# Generate an XML coverage report at coverage.xml
[group('coverage')]
coverage-xml:
uv tool run --with "coverage[toml]" coverage xml
@echo "Coverage report generated in coverage.xml"
# Start a MongoDB server via drivers-evergreen-tools
[group('server')]
run-server *args="":
bash .evergreen/scripts/run-server.sh {{args}}

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import json
from typing import Any, Optional
from urllib.parse import quote
def _get_azure_response(
@ -29,7 +30,7 @@ def _get_azure_response(
url += "?api-version=2018-02-01"
url += f"&resource={resource}"
if client_id:
url += f"&client_id={client_id}"
url += f"&client_id={quote(client_id)}"
headers = {"Metadata": "true", "Accept": "application/json"}
request = Request(url, headers=headers) # noqa: S310
try:

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import re
from typing import List, Tuple, Union
__version__ = "4.17.0.dev0"
__version__ = "4.18.0.dev0"
def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]:

View File

@ -59,6 +59,7 @@ from pymongo.errors import (
InvalidOperation,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
@ -563,9 +564,17 @@ class _AsyncClientBulk:
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = isinstance(
error, PyMongoError
) and error.has_error_label("RetryableError")
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -135,7 +135,9 @@ Classes
from __future__ import annotations
import asyncio
import collections
import random
import time
import uuid
from collections.abc import Mapping as _Mapping
@ -162,7 +164,9 @@ from pymongo.asynchronous.cursor_base import _ConnectionManager
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
@ -427,6 +431,7 @@ class _Transaction:
self.recovery_token = None
self.attempt = 0
self.client = client
self.has_completed_command = False
def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -434,6 +439,9 @@ class _Transaction:
def starting(self) -> bool:
return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[AsyncConnection]:
if self.active() and self.conn_mgr:
@ -459,6 +467,7 @@ class _Transaction:
self.sharded = False
self.recovery_token = None
self.attempt = 0
self.has_completed_command = False
def __del__(self) -> None:
if self.conn_mgr:
@ -493,11 +502,29 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T")
@ -744,7 +771,17 @@ class AsyncClientSession:
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
"""
start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
await asyncio.sleep(backoff)
retry += 1
await self.start_transaction(
read_concern, write_concern, read_preference, max_commit_time_ms
)
@ -752,15 +789,16 @@ class AsyncClientSession:
ret = await callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
last_error = exc
if self.in_transaction:
await self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise
if not self.in_transaction:
@ -771,17 +809,18 @@ class AsyncClientSession:
try:
await self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise
@ -1062,7 +1101,11 @@ class AsyncClientSession:
read_preference: _ServerMode,
conn: AsyncConnection,
) -> None:
if not conn.supports_sessions:
# getMores must be sent with a session if the cursor was opened with one
operation = next(iter(command))
if not conn.supports_sessions and (
isinstance(self._server_session, _EmptyServerSession) or operation != "getMore"
):
if not self._implicit:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return

View File

@ -20,7 +20,6 @@ from collections import abc
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Coroutine,
Generic,
@ -571,11 +570,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
await change_stream._initialize_cursor()
return change_stream
async def _conn_for_writes(
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
return await self._database.client._conn_for_writes(session, operation)
async def _command(
self,
conn: AsyncConnection,
@ -652,7 +646,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
async with await self._conn_for_writes(session, operation=_Op.CREATE) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
@ -669,6 +666,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.CREATE)
async def _create(
self,
options: MutableMapping[str, Any],
@ -2240,7 +2239,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
async with await self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2269,7 +2271,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session),
session=session,
)
return names
return names
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_INDEXES
)
async def create_index(
self,
@ -2422,7 +2428,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
await self._drop_index("*", session=session, **kwargs)
@_csot.apply
async def drop_index(
self,
index_or_name: _IndexKeyHint,
@ -2490,7 +2495,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
@ -2500,6 +2508,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
async def list_indexes(
self,
session: Optional[AsyncClientSession] = None,
@ -2763,17 +2773,22 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs)
async with await self._conn_for_writes(
session, operation=_Op.CREATE_SEARCH_INDEXES
) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
resp = await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
return [index["name"] for index in resp["indexesCreated"]]
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
async def drop_search_index(
self,
name: str,
@ -2799,15 +2814,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
async def update_search_index(
self,
name: str,
@ -2835,15 +2856,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
async def options(
self,
session: Optional[AsyncClientSession] = None,
@ -2918,6 +2945,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session,
retryable=not cmd._performs_write,
operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
)
async def aggregate(
@ -3123,17 +3151,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
async with await self._conn_for_writes(session, operation=_Op.RENAME) as conn:
async with self._database.client._tmp_session(session) as s:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=s,
client=self._database.client,
)
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> MutableMapping[str, Any]:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=session,
client=client,
)
return await client._retryable_write(False, inner, session, _Op.RENAME)
async def distinct(
self,

View File

@ -931,14 +931,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, session, operation=command_name
) as (
connection,
read_preference,
):
async def inner(
session: Optional[AsyncClientSession],
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return await self._command(
connection,
conn,
command,
value,
check,
@ -949,6 +950,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
**kwargs,
)
return await self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply
async def cursor_command(
self,
@ -1016,17 +1021,17 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
read_preference = (
tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, tmp_session, command_name
) as (
conn,
read_preference,
):
async def inner(
session: Optional[AsyncClientSession],
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
response = await self._command(
conn,
command,
@ -1035,7 +1040,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
None,
read_preference,
opts,
session=tmp_session,
session=session,
**kwargs,
)
coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1045,7 +1050,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
response["cursor"],
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
session=session,
comment=comment,
)
await cmd_cursor._maybe_pin_connection(conn)
@ -1053,6 +1058,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
else:
raise InvalidOperation("Command does not return a cursor.")
return await self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
async def _retryable_read_command(
self,
command: Union[str, MutableMapping[str, Any]],
@ -1254,9 +1263,11 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
async with await self._client._conn_for_writes(session, operation=_Op.DROP) as connection:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> dict[str, Any]:
return await self._command(
connection,
conn,
command,
allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session),
@ -1264,6 +1275,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
session=session,
)
return await self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply
async def drop_collection(
self,

View File

@ -17,8 +17,11 @@ from __future__ import annotations
import asyncio
import builtins
import functools
import random
import socket
import sys
import time as time # noqa: PLC0414 # needed in sync version
from typing import (
Any,
Callable,
@ -26,6 +29,8 @@ from typing import (
cast,
)
from pymongo import _csot
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import (
OperationFailure,
)
@ -38,6 +43,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.pool import AsyncConnection
@ -70,6 +76,46 @@ def _handle_reauth(func: F) -> F:
return cast(F, inner)
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
jitter = random.random() # noqa: S311
return jitter * min(initial_delay * (2**attempt), max_delay)
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter."""
def __init__(
self,
attempts: int = MAX_ADAPTIVE_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given attempt."""
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
async def should_retry(self, attempt: int, delay: float) -> bool:
"""Return if we have retry attempts remaining and the next backoff would not exceed a timeout."""
if attempt > self.attempts:
return False
if _csot.get_timeout():
if time.monotonic() + delay > _csot.get_deadline():
return False
return True
async def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[

View File

@ -35,6 +35,7 @@ from __future__ import annotations
import asyncio
import contextlib
import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings
import weakref
from collections import defaultdict
@ -67,6 +68,9 @@ from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterCh
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import (
_RetryPolicy,
)
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions
@ -610,8 +614,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details.
| **Overload retry options:**
- `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``.
- `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client.
If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments.
.. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument.
@ -879,11 +893,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1991,6 +2008,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Internal retryable helper for all client transactions.
@ -2002,6 +2021,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func()
"""
@ -2016,6 +2037,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run()
async def _retryable_read(
@ -2027,6 +2050,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None,
retryable: bool = True,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Execute an operation with consecutive retries if possible
@ -2042,6 +2067,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
"""
# Ensure that the client supports retrying on reads and there is no session in
@ -2060,6 +2087,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
)
async def _retryable_write(
@ -2454,15 +2483,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}"
)
async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
await self[name]._command(
conn,
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
await self[name].command(
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
@_csot.apply
async def bulk_write(
@ -2746,12 +2773,15 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
):
self._last_error: Optional[Exception] = None
self._retrying = False
self._multiple_retries = _csot.get_timeout() is not None
self._always_retryable = False
self._max_retries = float("inf") if _csot.get_timeout() is not None else 1
self._client = mongo_client
self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
@ -2767,6 +2797,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation
self._operation_id = operation_id
self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
async def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@ -2786,7 +2818,13 @@ class _ClientConnectionRetryable(Generic[T]):
while True:
self._check_last_error(check_csot=True)
try:
return await self._read() if self._is_read else await self._write()
res = await self._read() if self._is_read else await self._write()
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
@ -2797,37 +2835,80 @@ class _ClientConnectionRetryable(Generic[T]):
# most likely be a waste of time.
raise
except PyMongoError as exc:
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
overloaded = exc.has_error_label("SystemOverloadedError")
if overloaded:
self._max_retries = self._client.options.max_adaptive_retries
always_retryable = exc.has_error_label("RetryableError") and overloaded
if not self._client.options.retry_reads or (
not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
):
raise
self._retrying = True
self._last_error = exc
self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else:
raise
# Specialized catch on write operation
if not self._is_read:
if not self._retryable:
if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
if overloaded:
self._max_retries = self._client.options.max_adaptive_retries
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
if retryable_write_label or always_retryable:
assert self._session
await self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
if exc.has_error_label("NoWritesPerformed") and self._last_error:
if not always_retryable and (
not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
@ -2836,21 +2917,39 @@ class _ClientConnectionRetryable(Generic[T]):
self._bulk.retrying = True
else:
self._retrying = True
if not exc.has_error_label("NoWritesPerformed"):
if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
if self._server is not None and (
self._client.topology_description.topology_type_name == "Sharded"
or (overloaded and self._client.options.enable_overload_retargeting)
):
self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number)
if not await self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
await asyncio.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
return not self._retryable or (
self._is_retrying() and self._attempt_number >= self._max_retries
)
def _is_retrying(self) -> bool:
"""Checks if the exchange is currently undergoing a retry"""
@ -2909,7 +3008,7 @@ class _ClientConnectionRetryable(Generic[T]):
and conn.supports_sessions
)
is_mongos = conn.is_mongos
if not sessions_supported:
if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
@ -2941,7 +3040,7 @@ class _ClientConnectionRetryable(Generic[T]):
conn,
read_pref,
):
if self._retrying and not self._retryable:
if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error()
if self._retrying:
_debug_log(

View File

@ -19,6 +19,8 @@ import collections
import contextlib
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -52,10 +54,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -250,6 +254,7 @@ class AsyncConnection:
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
cmd["backpressure"] = True
if performing_handshake:
self.performed_handshake = True
cmd["client"] = self.opts.metadata
@ -752,14 +757,10 @@ class Pool:
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _async_create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_created(
self.address, self.opts.non_default_options
)
# Log before publishing event to prevent potential listener preemption in tests
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -769,6 +770,11 @@ class Pool:
serverPort=self.address[1],
**self.opts.non_default_options,
)
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_created(
self.address, self.opts.non_default_options
)
# Similar to active_sockets but includes threads in the wait queue.
self.operation_count: int = 0
# Retain references to pinned connections to prevent the CPython GC
@ -783,9 +789,6 @@ class Pool:
async with self.lock:
if self.state != PoolState.READY:
self.state = PoolState.READY
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -794,6 +797,9 @@ class Pool:
serverHost=self.address[0],
serverPort=self.address[1],
)
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
@property
def closed(self) -> bool:
@ -854,9 +860,6 @@ class Pool:
else:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -865,15 +868,11 @@ class Pool:
serverHost=self.address[0],
serverPort=self.address[1],
)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
else:
if old_state != PoolState.PAUSED:
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_cleared(
self.address,
service_id=service_id,
interrupt_connections=interrupt_connections,
)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -883,6 +882,13 @@ class Pool:
serverPort=self.address[1],
serviceId=service_id,
)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_cleared(
self.address,
service_id=service_id,
interrupt_connections=interrupt_connections,
)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
@ -986,6 +992,21 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
"""Connect to Mongo and return a new AsyncConnection.
@ -1037,10 +1058,10 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
raise
conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
@ -1049,18 +1070,22 @@ class Pool:
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
@ -1389,8 +1414,8 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
pool to keep performance reasonable -
we can't avoid AutoReconnects completely anyway.
"""
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
@ -1401,8 +1426,9 @@ class Pool:
await conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
check_interval_seconds = self._check_interval_seconds
if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
):
if conn.conn_closed():
await conn.close_conn(ConnectionClosedReason.ERROR)

View File

@ -913,7 +913,9 @@ class Topology:
# Clear the pool.
await server.reset(service_id)
elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError):
if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."

View File

@ -235,6 +235,16 @@ class ClientOptions:
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
self.__max_adaptive_retries = (
options.get("max_adaptive_retries", common.MAX_ADAPTIVE_RETRIES)
if "max_adaptive_retries" in options
else options.get("maxadaptiveretries", common.MAX_ADAPTIVE_RETRIES)
)
self.__enable_overload_retargeting = (
options.get("enable_overload_retargeting", common.ENABLE_OVERLOAD_RETARGETING)
if "enable_overload_retargeting" in options
else options.get("enableoverloadretargeting", common.ENABLE_OVERLOAD_RETARGETING)
)
@property
def _options(self) -> Mapping[str, Any]:
@ -346,3 +356,19 @@ class ClientOptions:
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode
@property
def max_adaptive_retries(self) -> int:
"""The configured maxAdaptiveRetries option.
.. versionadded:: 4.17
"""
return self.__max_adaptive_retries
@property
def enable_overload_retargeting(self) -> bool:
"""The configured enableOverloadRetargeting option.
.. versionadded:: 4.17
"""
return self.__enable_overload_retargeting

View File

@ -140,6 +140,12 @@ SRV_SERVICE_NAME = "mongodb"
# Default value for serverMonitoringMode
SERVER_MONITORING_MODE = "auto" # poll/stream/auto
# Default value for max adaptive retries
MAX_ADAPTIVE_RETRIES = 2
# Default value for enableOverloadRetargeting
ENABLE_OVERLOAD_RETARGETING = False
# Auth mechanism properties that must raise an error instead of warning if they invalidate.
_MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"]
@ -233,13 +239,6 @@ def validate_readable(option: str, value: Any) -> Optional[str]:
return value
def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]:
"""Validate that 'value' is a positive integer or None."""
if value is None:
return value
return validate_positive_integer(option, value)
def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]:
"""Validate that 'value' is a positive integer or 0 or None."""
if value is None:
@ -261,20 +260,6 @@ def validate_string_or_none(option: str, value: Any) -> Optional[str]:
return validate_string(option, value)
def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
"""Validates that 'value' is an integer or string."""
if isinstance(value, int):
return value
elif isinstance(value, str):
try:
return int(value)
except ValueError:
return value
raise TypeError(
f"Wrong type for {option}, value must be an integer or a string, not {type(value)}"
)
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
"""Validates that 'value' is an integer or string."""
if isinstance(value, int):
@ -738,6 +723,8 @@ URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = {
"srvmaxhosts": validate_non_negative_integer,
"timeoutms": validate_timeoutms,
"servermonitoringmode": validate_server_monitoring_mode,
"maxadaptiveretries": validate_non_negative_integer,
"enableoverloadretargeting": validate_boolean_or_string,
}
# Dictionary where keys are the names of URI options specific to pymongo,
@ -771,6 +758,8 @@ KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = {
"server_selector": validate_is_callable_or_none,
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list,
"max_adaptive_retries": validate_non_negative_integer,
"enable_overload_retargeting": validate_boolean_or_string,
}
# Dictionary where keys are any URI option name, and values are the
@ -817,16 +806,6 @@ TIMEOUT_OPTIONS: list[str] = [
"waitqueuetimeoutms",
]
_AUTH_OPTIONS = frozenset(["authmechanismproperties"])
def validate_auth_option(option: str, value: Any) -> tuple[str, Any]:
"""Validate optional authentication parameters."""
lower, value = validate(option, value)
if lower not in _AUTH_OPTIONS:
raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}")
return option, value
def _get_validator(
key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None

View File

@ -59,6 +59,7 @@ from pymongo.errors import (
InvalidOperation,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
@ -561,9 +562,17 @@ class _ClientBulk:
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = isinstance(
error, PyMongoError
) and error.has_error_label("RetryableError")
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -136,6 +136,7 @@ Classes
from __future__ import annotations
import collections
import random
import time
import uuid
from collections.abc import Mapping as _Mapping
@ -160,7 +161,9 @@ from pymongo import _csot
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
@ -426,6 +429,7 @@ class _Transaction:
self.recovery_token = None
self.attempt = 0
self.client = client
self.has_completed_command = False
def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -433,6 +437,9 @@ class _Transaction:
def starting(self) -> bool:
return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[Connection]:
if self.active() and self.conn_mgr:
@ -458,6 +465,7 @@ class _Transaction:
self.sharded = False
self.recovery_token = None
self.attempt = 0
self.has_completed_command = False
def __del__(self) -> None:
if self.conn_mgr:
@ -492,11 +500,29 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T")
@ -743,21 +769,32 @@ class ClientSession:
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
"""
start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
time.sleep(backoff)
retry += 1
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try:
ret = callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
last_error = exc
if self.in_transaction:
self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise
if not self.in_transaction:
@ -768,17 +805,18 @@ class ClientSession:
try:
self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise
@ -1059,7 +1097,11 @@ class ClientSession:
read_preference: _ServerMode,
conn: Connection,
) -> None:
if not conn.supports_sessions:
# getMores must be sent with a session if the cursor was opened with one
operation = next(iter(command))
if not conn.supports_sessions and (
isinstance(self._server_session, _EmptyServerSession) or operation != "getMore"
):
if not self._implicit:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return

View File

@ -21,7 +21,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Generic,
Iterable,
Iterator,
@ -572,11 +571,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
change_stream._initialize_cursor()
return change_stream
def _conn_for_writes(
self, session: Optional[ClientSession], operation: str
) -> ContextManager[Connection]:
return self._database.client._conn_for_writes(session, operation)
def _command(
self,
conn: Connection,
@ -653,7 +647,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
with self._conn_for_writes(session, operation=_Op.CREATE) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
@ -670,6 +667,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.CREATE)
def _create(
self,
options: MutableMapping[str, Any],
@ -2237,7 +2236,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
with self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2266,7 +2268,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session),
session=session,
)
return names
return names
return self.database.client._retryable_write(False, inner, session, _Op.CREATE_INDEXES)
def create_index(
self,
@ -2419,7 +2423,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
self._drop_index("*", session=session, **kwargs)
@_csot.apply
def drop_index(
self,
index_or_name: _IndexKeyHint,
@ -2487,7 +2490,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command(
conn,
cmd,
@ -2497,6 +2503,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
def list_indexes(
self,
session: Optional[ClientSession] = None,
@ -2760,15 +2768,22 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs)
with self._conn_for_writes(session, operation=_Op.CREATE_SEARCH_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> list[str]:
resp = self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
return [index["name"] for index in resp["indexesCreated"]]
return self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
def drop_search_index(
self,
name: str,
@ -2794,15 +2809,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
def update_search_index(
self,
name: str,
@ -2830,15 +2851,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
def options(
self,
session: Optional[ClientSession] = None,
@ -2911,6 +2938,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session,
retryable=not cmd._performs_write,
operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
)
def aggregate(
@ -3116,17 +3144,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
with self._conn_for_writes(session, operation=_Op.RENAME) as conn:
with self._database.client._tmp_session(session) as s:
return conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=s,
client=self._database.client,
)
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> MutableMapping[str, Any]:
return conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=session,
client=client,
)
return client._retryable_write(False, inner, session, _Op.RENAME)
def distinct(
self,

View File

@ -931,12 +931,15 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
with self._client._conn_for_reads(read_preference, session, operation=command_name) as (
connection,
read_preference,
):
def inner(
session: Optional[ClientSession],
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return self._command(
connection,
conn,
command,
value,
check,
@ -947,6 +950,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
**kwargs,
)
return self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply
def cursor_command(
self,
@ -1014,15 +1021,17 @@ class Database(common.BaseObject, Generic[_DocumentType]):
with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
read_preference = (
tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY
with self._client._conn_for_reads(read_preference, tmp_session, command_name) as (
conn,
read_preference,
):
def inner(
session: Optional[ClientSession],
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]:
response = self._command(
conn,
command,
@ -1031,7 +1040,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
None,
read_preference,
opts,
session=tmp_session,
session=session,
**kwargs,
)
coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1041,7 +1050,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
response["cursor"],
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
session=session,
comment=comment,
)
cmd_cursor._maybe_pin_connection(conn)
@ -1049,6 +1058,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
else:
raise InvalidOperation("Command does not return a cursor.")
return self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
def _retryable_read_command(
self,
command: Union[str, MutableMapping[str, Any]],
@ -1247,9 +1260,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
with self._client._conn_for_writes(session, operation=_Op.DROP) as connection:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> dict[str, Any]:
return self._command(
connection,
conn,
command,
allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session),
@ -1257,6 +1272,8 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session=session,
)
return self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply
def drop_collection(
self,

View File

@ -17,8 +17,11 @@ from __future__ import annotations
import asyncio
import builtins
import functools
import random
import socket
import sys
import time as time # noqa: PLC0414 # needed in sync version
from typing import (
Any,
Callable,
@ -26,6 +29,8 @@ from typing import (
cast,
)
from pymongo import _csot
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import (
OperationFailure,
)
@ -38,6 +43,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
@functools.wraps(func)
def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.message import _BulkWriteContext
@ -70,6 +76,46 @@ def _handle_reauth(func: F) -> F:
return cast(F, inner)
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
jitter = random.random() # noqa: S311
return jitter * min(initial_delay * (2**attempt), max_delay)
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter."""
def __init__(
self,
attempts: int = MAX_ADAPTIVE_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given attempt."""
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
def should_retry(self, attempt: int, delay: float) -> bool:
"""Return if we have retry attempts remaining and the next backoff would not exceed a timeout."""
if attempt > self.attempts:
return False
if _csot.get_timeout():
if time.monotonic() + delay > _csot.get_deadline():
return False
return True
def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[

View File

@ -35,6 +35,7 @@ from __future__ import annotations
import asyncio
import contextlib
import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings
import weakref
from collections import defaultdict
@ -110,6 +111,9 @@ from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.helpers import (
_RetryPolicy,
)
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@ -610,8 +614,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details.
| **Overload retry options:**
- `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``.
- `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client.
If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments.
.. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument.
@ -879,11 +893,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1987,6 +2004,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Internal retryable helper for all client transactions.
@ -1998,6 +2017,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func()
"""
@ -2012,6 +2033,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run()
def _retryable_read(
@ -2023,6 +2046,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None,
retryable: bool = True,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Execute an operation with consecutive retries if possible
@ -2038,6 +2063,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
"""
# Ensure that the client supports retrying on reads and there is no session in
@ -2056,6 +2083,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
)
def _retryable_write(
@ -2444,15 +2473,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a Database, not {type(name)}"
)
with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
self[name]._command(
conn,
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
self[name].command(
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
@_csot.apply
def bulk_write(
@ -2736,12 +2763,15 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
):
self._last_error: Optional[Exception] = None
self._retrying = False
self._multiple_retries = _csot.get_timeout() is not None
self._always_retryable = False
self._max_retries = float("inf") if _csot.get_timeout() is not None else 1
self._client = mongo_client
self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
@ -2757,6 +2787,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation
self._operation_id = operation_id
self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@ -2776,7 +2808,13 @@ class _ClientConnectionRetryable(Generic[T]):
while True:
self._check_last_error(check_csot=True)
try:
return self._read() if self._is_read else self._write()
res = self._read() if self._is_read else self._write()
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
@ -2787,37 +2825,80 @@ class _ClientConnectionRetryable(Generic[T]):
# most likely be a waste of time.
raise
except PyMongoError as exc:
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
overloaded = exc.has_error_label("SystemOverloadedError")
if overloaded:
self._max_retries = self._client.options.max_adaptive_retries
always_retryable = exc.has_error_label("RetryableError") and overloaded
if not self._client.options.retry_reads or (
not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
):
raise
self._retrying = True
self._last_error = exc
self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else:
raise
# Specialized catch on write operation
if not self._is_read:
if not self._retryable:
if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
if overloaded:
self._max_retries = self._client.options.max_adaptive_retries
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
if retryable_write_label or always_retryable:
assert self._session
self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
if exc.has_error_label("NoWritesPerformed") and self._last_error:
if not always_retryable and (
not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
@ -2826,21 +2907,39 @@ class _ClientConnectionRetryable(Generic[T]):
self._bulk.retrying = True
else:
self._retrying = True
if not exc.has_error_label("NoWritesPerformed"):
if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
if self._server is not None and (
self._client.topology_description.topology_type_name == "Sharded"
or (overloaded and self._client.options.enable_overload_retargeting)
):
self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number)
if not self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
time.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
return not self._retryable or (
self._is_retrying() and self._attempt_number >= self._max_retries
)
def _is_retrying(self) -> bool:
"""Checks if the exchange is currently undergoing a retry"""
@ -2899,7 +2998,7 @@ class _ClientConnectionRetryable(Generic[T]):
and conn.supports_sessions
)
is_mongos = conn.is_mongos
if not sessions_supported:
if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
@ -2931,7 +3030,7 @@ class _ClientConnectionRetryable(Generic[T]):
conn,
read_pref,
):
if self._retrying and not self._retryable:
if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error()
if self._retrying:
_debug_log(

View File

@ -19,6 +19,8 @@ import collections
import contextlib
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -49,10 +51,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -250,6 +254,7 @@ class Connection:
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
cmd["backpressure"] = True
if performing_handshake:
self.performed_handshake = True
cmd["client"] = self.opts.metadata
@ -750,14 +755,10 @@ class Pool:
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_created(
self.address, self.opts.non_default_options
)
# Log before publishing event to prevent potential listener preemption in tests
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -767,6 +768,11 @@ class Pool:
serverPort=self.address[1],
**self.opts.non_default_options,
)
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_created(
self.address, self.opts.non_default_options
)
# Similar to active_sockets but includes threads in the wait queue.
self.operation_count: int = 0
# Retain references to pinned connections to prevent the CPython GC
@ -781,9 +787,6 @@ class Pool:
with self.lock:
if self.state != PoolState.READY:
self.state = PoolState.READY
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -792,6 +795,9 @@ class Pool:
serverHost=self.address[0],
serverPort=self.address[1],
)
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
@property
def closed(self) -> bool:
@ -852,9 +858,6 @@ class Pool:
else:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -863,15 +866,11 @@ class Pool:
serverHost=self.address[0],
serverPort=self.address[1],
)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
else:
if old_state != PoolState.PAUSED:
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_cleared(
self.address,
service_id=service_id,
interrupt_connections=interrupt_connections,
)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
@ -881,6 +880,13 @@ class Pool:
serverPort=self.address[1],
serviceId=service_id,
)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_cleared(
self.address,
service_id=service_id,
interrupt_connections=interrupt_connections,
)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
@ -982,6 +988,21 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
"""Connect to Mongo and return a new Connection.
@ -1033,10 +1054,10 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
raise
conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
@ -1045,18 +1066,22 @@ class Pool:
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise
@ -1385,8 +1410,8 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
pool to keep performance reasonable -
we can't avoid AutoReconnects completely anyway.
"""
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
@ -1397,8 +1422,9 @@ class Pool:
conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
check_interval_seconds = self._check_interval_seconds
if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
):
if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR)

View File

@ -911,7 +911,9 @@ class Topology:
# Clear the pool.
server.reset(service_id)
elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError):
if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."

View File

@ -652,6 +652,38 @@ class AsyncClientUnitTest(AsyncUnitTest):
with self.assertWarns(UserWarning):
self.simple_client(multi_host)
async def test_max_adaptive_retries(self):
# Assert that max adaptive retries defaults to 2.
c = self.simple_client(connect=False)
self.assertEqual(c.options.max_adaptive_retries, 2)
# Assert that max adaptive retries can be configured through connection or client options.
c = self.simple_client(connect=False, max_adaptive_retries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(connect=False, maxAdaptiveRetries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False)
self.assertEqual(c.options.max_adaptive_retries, 10)
async def test_enable_overload_retargeting(self):
# Assert that overload retargeting defaults to false.
c = self.simple_client(connect=False)
self.assertFalse(c.options.enable_overload_retargeting)
# Assert that overload retargeting can be enabled through connection or client options.
c = self.simple_client(connect=False, enable_overload_retargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(connect=False, enableOverloadRetargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(
host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False
)
self.assertTrue(c.options.enable_overload_retargeting)
class TestClient(AsyncIntegrationTest):
def test_multiple_uris(self):
@ -1034,7 +1066,7 @@ class TestClient(AsyncIntegrationTest):
db_names = await self.client.list_database_names()
self.assertIn("pymongo_test", db_names)
self.assertIn("pymongo_test_mike", db_names)
self.assertEqual(db_names, cmd_names)
self.assertCountEqual(db_names, cmd_names)
async def test_drop_database(self):
with self.assertRaises(TypeError):
@ -2679,11 +2711,11 @@ class TestClientPool(AsyncMockClientTest):
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3))
# Assert that we create 1 pooled connection.
# Wait for the pooled connection to be registered
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.conns), 1)
await async_wait_until(lambda: len(arbiter.pool.conns) == 1, "create 1 pooled connection")
# Arbiter pool is marked ready.
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)

View File

@ -0,0 +1,312 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Client Backpressure spec."""
from __future__ import annotations
import os
import pathlib
import sys
from time import perf_counter
from unittest.mock import patch
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
unittest,
)
from test.asynchronous.unified_format import generate_test_classes
from test.utils_shared import EventListener, OvertCommandListener
from pymongo.errors import OperationFailure, PyMongoError
_IS_SYNC = False
# Mock a system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def get_mock_overload_error(times: int):
error = mock_overload_error.copy()
error["mode"] = {"times": times}
return error
class TestBackpressure(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_command(self):
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.command("find", "t")
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.command("find", "t")
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_find(self):
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.find_one()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.find_one()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_insert_one(self):
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.insert_one({"x": 1})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.insert_one({"x": 1})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_update_many(self):
# Even though update_many is not a retryable write operation, it will
# still be retried via the "RetryableError" error label.
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.update_many({}, {"$set": {"x": 2}})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.update_many({}, {"$set": {"x": 2}})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_getMore(self):
coll = self.db.t
await coll.insert_many([{"x": 1} for _ in range(10)])
# Ensure command is retried on overload error.
fail_many = {
"configureFailPoint": "failCommand",
"mode": {"times": MAX_ADAPTIVE_RETRIES},
"data": {
"failCommands": ["getMore"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
cursor = coll.find(batch_size=2)
await cursor.next()
async with self.fail_point(fail_many):
await cursor.to_list()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = fail_many.copy()
fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1}
cursor = coll.find(batch_size=2)
await cursor.next()
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await cursor.to_list()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# Prose tests.
class AsyncTestClientBackpressure(AsyncIntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener.reset()
self.app_name = self.__class__.__name__.lower()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], appName=self.app_name
)
@patch("random.random")
@async_client_context.require_failCommand_appName
async def test_01_operation_retry_uses_exponential_backoff(self, random_func):
# Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered.
# 1. let `client` be a `MongoClient`
client = self.client
# 2. let `collection` be a collection
collection = client.test.test
# 3. Now, run transactions without backoff:
# a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff.
random_func.return_value = 0
# b. Configure the following failPoint:
fail_point = dict(
mode="alwaysOn",
data=dict(
failCommands=["insert"],
errorCode=2,
errorLabels=["SystemOverloadedError", "RetryableError"],
appName=self.app_name,
),
)
async with self.fail_point(fail_point):
# c. Execute the following command. Expect that the command errors. Measure the duration of the command execution.
start0 = perf_counter()
with self.assertRaises(OperationFailure):
await collection.insert_one({"a": 1})
end0 = perf_counter()
# d. Configure the random number generator used for jitter to always return `1`.
random_func.return_value = 1
# e. Execute step c again.
start1 = perf_counter()
with self.assertRaises(OperationFailure):
await collection.insert_one({"a": 1})
end1 = perf_counter()
# f. Compare the times between the two runs.
# The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two
# runs.
self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3)
@async_client_context.require_failCommand_appName
async def test_03_overload_retries_limited(self):
# Drivers should test that overload errors are retried a maximum of two times.
# 1. Let `client` be a `MongoClient`.
client = self.client
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
async with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
await coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1.
self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1)
@async_client_context.require_failCommand_appName
async def test_04_overload_retries_limited_configured(self):
# Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times.
max_retries = 1
# 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled.
client = await self.async_single_client(
maxAdaptiveRetries=max_retries, event_listeners=[self.listener]
)
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
async with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
await coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is max_retries + 1.
self.assertEqual(len(self.listener.started_events), max_retries + 1)
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -219,6 +219,19 @@ class TestClientMetadataProse(AsyncIntegrationTest):
# add same metadata again
await self.check_metadata_added(client, "Framework", None, None)
async def test_handshake_documents_include_backpressure(self):
# Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of
# connection establishment.
client = await self.async_rs_or_single_client("mongodb://" + self.server.address_string)
# Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is
# established on all topologies. Note: MockupDB only supports standalone servers.
await client.admin.command("ping")
# Assert that for every handshake document intercepted:
# the document has a field `backpressure` whose value is `true`.
self.assertEqual(self.handshake_req["backpressure"], True)
if __name__ == "__main__":
unittest.main()

View File

@ -257,7 +257,6 @@ class TestCollation(AsyncIntegrationTest):
self.assertEqual(
ja_collation.document["locale"], indexes["japanese_version"]["collation"]["locale"]
)
self.assertNotIn("collation", indexes["simple"])
await self.db.test.drop_index("fieldname_1")
indexes = await self.db.test.index_information()
self.assertIn("japanese_version", indexes)

View File

@ -25,8 +25,10 @@ from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.asynchronous.helpers import ConcurrentRunner
from test.asynchronous.utils import flaky
from test.utils_shared import delay
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.errors import ConnectionFailure
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
@ -70,7 +72,12 @@ from pymongo.errors import (
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionCheckOutFailedEvent,
PoolClearedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
)
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
@ -131,6 +138,9 @@ async def got_app_error(topology, app_error):
raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes":
# The pool would have added the SystemOverloadedError in this case.
if isinstance(e, AutoReconnect):
e._add_error_label("SystemOverloadedError")
completed_handshake = False
elif when == "afterHandshakeCompletes":
completed_handshake = True
@ -439,6 +449,59 @@ class TestPoolManagement(AsyncIntegrationTest):
AsyncConnection.close_conn = original_close
class TestPoolBackpressure(AsyncIntegrationTest):
@async_client_context.require_version_min(7, 0, 0)
async def test_connection_pool_is_not_cleared(self):
listener = CMAPListener()
# Create a client that listens to CMAP events, with maxConnecting=100.
client = await self.async_rs_or_single_client(maxConnecting=100, event_listeners=[listener])
# Enable the ingress rate limiter.
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True
)
await client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20)
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1
)
await client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1)
# Disable the ingress rate limiter on teardown.
# Sleep for 1 second before disabling to avoid the rate limiter.
async def teardown():
await asyncio.sleep(1)
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False
)
self.addAsyncCleanup(teardown)
# Make sure the collection has at least one document.
await client.test.test.delete_many({})
await client.test.test.insert_one({})
# Run a slow operation to tie up the connection.
async def target():
try:
await client.test.test.find_one({"$where": delay(0.1)})
except ConnectionFailure:
pass
# Run 100 parallel operations that contend for connections.
tasks = []
for _ in range(100):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
await t.start()
for t in tasks:
await t.join()
# Verify there were at least 10 connection checkout failed event but no pool cleared events.
self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
class TestServerMonitoringMode(AsyncIntegrationTest):
@async_client_context.require_no_load_balancer
async def asyncSetUp(self):

View File

@ -876,8 +876,6 @@ class TestViews(AsyncEncryptionIntegrationTest):
class TestCorpus(AsyncEncryptionIntegrationTest):
# PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions.
@async_client_context.require_version_max(6, 99)
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def asyncSetUp(self):
await super().asyncSetUp()
@ -1054,8 +1052,6 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
client_encrypted: AsyncMongoClient
listener: OvertCommandListener
# PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions.
@async_client_context.require_version_max(6, 99)
async def asyncSetUp(self):
await super().asyncSetUp()
db = async_client_context.client.db
@ -3326,6 +3322,7 @@ class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
class TestExplicitTextEncryptionProse(AsyncEncryptionIntegrationTest):
@async_client_context.require_no_standalone
@async_client_context.require_version_min(8, 2, -1)
@async_client_context.require_version_max(8, 99, 99)
@async_client_context.require_libmongocrypt_min(1, 15, 1)
@async_client_context.require_pymongocrypt_min(1, 16, 0)
async def asyncSetUp(self):

View File

@ -513,6 +513,39 @@ class TestPooling(_TestPoolingBase):
str(error.exception),
)
@async_client_context.require_failCommand_appName
async def test_pool_backpressure_preserves_existing_connections(self):
client = await self.async_rs_or_single_client()
coll = client.pymongo_test.t
pool = await async_get_pool(client)
await coll.insert_many([{"x": 1} for _ in range(10)])
t = SocketGetter(self.c, pool)
await t.start()
while t.state != "connection":
await asyncio.sleep(0.1)
assert not t.sock.conn_closed()
# Mock a session establishment overload.
mock_connection_fail = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"closeConnection": True,
},
}
async with self.fail_point(mock_connection_fail):
await coll.find_one({})
# Make sure the existing socket was not affected.
assert not t.sock.conn_closed()
# Cleanup
await t.release_conn()
await t.join()
await pool.close()
class TestPoolMaxSize(_TestPoolingBase):
async def test_max_pool_size(self):

View File

@ -19,9 +19,12 @@ import os
import pprint
import sys
import threading
from test.asynchronous.utils import async_set_fail_point
from test.asynchronous.utils import async_ensure_all_connected, async_set_fail_point
from unittest import mock
from pymongo.errors import OperationFailure
from pymongo import MongoClient
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import OperationFailure, PyMongoError
sys.path[0:0] = [""]
@ -38,6 +41,7 @@ from test.utils_shared import (
)
from pymongo.monitoring import (
CommandFailedEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutFailedReason,
@ -145,6 +149,19 @@ class TestPoolPausedError(AsyncIntegrationTest):
class TestRetryableReads(AsyncIntegrationTest):
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.setup_client = MongoClient(**async_client_context.client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand", **command_args}
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
@async_client_context.require_multiple_mongoses
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_a_different_mongos_when_one_is_available(self):
@ -265,16 +282,22 @@ class TestRetryableReads(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0)
async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available(
async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled.
client = await self.async_rs_or_single_client(
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred"
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
enableOverloadRetargeting=True,
)
# Ensure the client has discovered all nodes.
await async_ensure_all_connected(client)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
@ -314,6 +337,9 @@ class TestRetryableReads(AsyncIntegrationTest):
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred"
)
# Ensure the client has discovered all nodes.
await async_ensure_all_connected(client)
# 2. Configure a fail point with the RetryableError error label.
command_args = {
"configureFailPoint": "failCommand",
@ -339,6 +365,161 @@ class TestRetryableReads(AsyncIntegrationTest):
# 6. Assert that both events occurred the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@async_client_context.require_replica_set
@async_client_context.require_secondaries_count(1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0)
async def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
)
# Ensure the client has discovered all nodes.
await async_ensure_all_connected(client)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 6,
},
}
await async_set_fail_point(client, command_args)
# 3. Reset the command event monitor to clear the fail point command from its stored events.
listener.reset()
# 4. Execute a `find` command with `client`.
await client.t.t.find_one({})
# 5. Assert that one failed command event and one successful command event occurred.
self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)
# 6. Assert that both events occurred on the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0) # type:ignore[untyped-decorator]
async def test_overload_then_nonoverload_retries_increased_reads(self) -> None:
# Create a client.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with only the `RetryableError` error label.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 91,
"errorLabels": ["RetryableError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(event_listeners=[listener])
await client.test.test.insert_one({})
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
with self.assertRaises(PyMongoError):
await client.test.test.find_one()
started_finds = [e for e in listener.started_events if e.command_name == "find"]
self.assertEqual(len(started_finds), MAX_ADAPTIVE_RETRIES + 1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0) # type:ignore[untyped-decorator]
async def test_backoff_is_not_applied_for_non_overload_errors(self):
if _IS_SYNC:
mock_target = "pymongo.synchronous.helpers._RetryPolicy.backoff"
else:
mock_target = "pymongo.asynchronous.helpers._RetryPolicy.backoff"
# Create a client.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with only the `RetryableError` error label.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 91,
"errorLabels": ["RetryableError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(event_listeners=[listener])
await client.test.test.insert_one({})
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Perform a findOne operation with coll. Expect the operation to fail.
with mock.patch(mock_target, return_value=0) as mock_backoff:
with self.assertRaises(PyMongoError):
await client.test.test.find_one()
# Assert that backoff was applied only once for the initial overload error and not for the subsequent non-overload retryable errors.
self.assertEqual(mock_backoff.call_count, 1)
if __name__ == "__main__":
unittest.main()

View File

@ -21,6 +21,9 @@ import pprint
import sys
import threading
from test.asynchronous.utils import async_set_fail_point, flaky
from unittest import mock
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
@ -43,14 +46,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo import MongoClient
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
NotPrimaryError,
PyMongoError,
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import (
CommandFailedEvent,
CommandSucceededEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
@ -601,5 +607,291 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
self.assertEqual(sent_txn_id, final_txn_id, msg)
class TestErrorPropagationAfterEncounteringMultipleErrors(AsyncIntegrationTest):
# Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers.
@async_client_context.require_replica_set
# Run against server versions 6.0 and above.
@async_client_context.require_version_min(6, 0) # type: ignore[untyped-decorator]
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.setup_client = MongoClient(**async_client_context.default_client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand"}
cmd.update(command_args)
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
async def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 10107.
assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload]
async def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
# and a NoWritesPerformed label.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
},
}
def failed(event: CommandFailedEvent) -> None:
if listener.failed_events:
return
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91 # type:ignore[call-overload]
async def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
# `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(PyMongoError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91
# Assert that the error does not contain the error label `NoWritesPerformed`.
assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"]
async def test_overload_then_nonoverload_retries_increased_writes(self) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and `RetryableWriteError` error labels.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "RetryableWriteError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
with self.assertRaises(PyMongoError):
await client.test.test.insert_one({"x": 1})
started_inserts = [e for e in listener.started_events if e.command_name == "insert"]
self.assertEqual(len(started_inserts), MAX_ADAPTIVE_RETRIES + 1)
async def test_backoff_is_not_applied_for_non_overload_errors(self):
if _IS_SYNC:
mock_target = "pymongo.synchronous.helpers._RetryPolicy.backoff"
else:
mock_target = "pymongo.asynchronous.helpers._RetryPolicy.backoff"
# Create a client.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with only the `RetryableError` error label.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "RetryableWriteError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(event_listeners=[listener])
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Perform a findOne operation with coll. Expect the operation to fail.
with mock.patch(mock_target, return_value=0) as mock_backoff:
with self.assertRaises(PyMongoError):
await client.test.test.insert_one({})
# Assert that backoff was applied only once for the initial overload error and not for the subsequent non-overload retryable errors.
self.assertEqual(mock_backoff.call_count, 1)
if __name__ == "__main__":
unittest.main()

View File

@ -15,7 +15,6 @@
"""Test the client_session module."""
from __future__ import annotations
import asyncio
import copy
import sys
import time
@ -24,8 +23,6 @@ from io import BytesIO
from test.asynchronous.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple
from pymongo.synchronous.mongo_client import MongoClient
sys.path[0:0] = [""]
from test.asynchronous import (
@ -45,7 +42,7 @@ from test.utils_shared import (
from bson import DBRef
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring
from pymongo import ASCENDING, AsyncMongoClient, monitoring
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
@ -938,6 +935,39 @@ class TestSession(AsyncIntegrationTest):
await s2.end_session()
async def test_getmore_preserves_lsid_after_session_support_lost(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
coll = client.pymongo_test.test
await coll.drop()
await coll.insert_many([{"x": i} for i in range(10)])
self.addAsyncCleanup(coll.drop)
async with client.start_session() as s:
cursor = coll.find({}, batch_size=2, session=s)
await anext(cursor)
find_event = next(e for e in listener.started_events if e.command_name == "find")
lsid = find_event.command["lsid"]
# Simulate a node stepping down: mark idle connections as not supporting sessions.
for server in client._topology._servers.values():
for conn in server.pool.conns:
conn.supports_sessions = False
listener.reset()
await cursor.to_list()
getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
for event in getmore_events:
self.assertIn(
"lsid", event.command, "getMore must include lsid when session is materialized"
)
self.assertEqual(
lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
)
class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener

View File

@ -16,9 +16,13 @@
from __future__ import annotations
import asyncio
import random
import sys
import time
from io import BytesIO
from unittest.mock import patch
import pymongo
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo.asynchronous.pool import PoolState
from pymongo.server_selectors import writable_server_selector
@ -45,7 +49,9 @@ from pymongo.errors import (
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
@ -434,7 +440,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await self.configure_fail_point(client, command_args)
@async_client_context.require_transactions
async def test_callback_raises_custom_error(self):
async def test_1_callback_raises_custom_error(self):
class _MyException(Exception):
pass
@ -446,7 +452,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await s.with_transaction(raise_error)
@async_client_context.require_transactions
async def test_callback_returns_value(self):
async def test_2_callback_returns_value(self):
async def callback(_):
return "Foo"
@ -474,7 +480,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual(await s.with_transaction(callback), "Foo")
@async_client_context.require_transactions
async def test_callback_not_retried_after_timeout(self):
async def test_3_1_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -495,14 +501,16 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_callback_not_retried_after_commit_timeout(self):
async def test_3_2_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -529,14 +537,16 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_commit_not_retried_after_timeout(self):
async def test_3_3_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -560,7 +570,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
@ -568,6 +578,40 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@async_client_context.require_transactions
async def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
async def callback(session):
await coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
await coll.insert_one({})
listener.reset()
async with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
await s.with_transaction(callback)
# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])
self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)
# Tested here because this supports Motor's convenient transactions API.
@async_client_context.require_transactions
@ -606,6 +650,63 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await s.with_transaction(callback)
self.assertFalse(s.in_transaction)
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_4_retry_backoff_is_enforced(self):
client = async_client_context.client
coll = client[self.db.name].test
end = start = no_backoff_time = 0
# Make random.random always return 0 (no backoff)
with patch.object(random, "random", return_value=0):
# set fail point to trigger transaction failure and trigger backoff
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 13},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
async def callback(session):
await coll.insert_one({}, session=session)
start = time.monotonic()
async with self.client.start_session() as s:
await s.with_transaction(callback)
end = time.monotonic()
no_backoff_time = end - start
# Make random.random always return 1 (max backoff)
with patch.object(random, "random", return_value=1):
# set fail point to trigger transaction failure and trigger backoff
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {
"times": 13
}, # sufficiently high enough such that the time effect of backoff is noticeable
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
start = time.monotonic()
async with self.client.start_session() as s:
await s.with_transaction(callback)
end = time.monotonic()
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
class TestOptionsInsideTransactionProse(AsyncTransactionsBase):
@async_client_context.require_transactions

View File

@ -1464,11 +1464,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
self.assertListEqual(sorted_expected_documents, actual_documents)
async def run_scenario(self, spec, uri=None):
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
await self.kill_all_sessions()
# Handle flaky tests.
flaky_tests = [
("PYTHON-5170", ".*test_discovery_and_monitoring.*"),
@ -1504,6 +1499,15 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
if skip_reason is not None:
raise unittest.SkipTest(f"{skip_reason}")
# Kill all sessions after each test with transactions to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
for op in spec["operations"]:
name = op["name"]
if name == "startTransaction" or name == "withTransaction":
self.addAsyncCleanup(self.kill_all_sessions)
break
# process createEntities
self._uri = uri
self.entity_map = EntityMapUtil(self)

View File

@ -16,43 +16,13 @@
from __future__ import annotations
import asyncio
import functools
import os
import time
import unittest
from collections import abc
from inspect import iscoroutinefunction
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.asynchronous import async_client_context
from test.asynchronous.helpers import ConcurrentRunner
from test.utils_shared import (
CMAPListener,
CompareType,
EventListener,
OvertCommandListener,
ScenarioDict,
ServerAndTopologyEventListener,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
prepare_spec_arguments,
)
from typing import List
from test.utils_shared import ScenarioDict
from bson import ObjectId, decode, encode, json_util
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from gridfs.asynchronous.grid_file import AsyncGridFSBucket
from pymongo.asynchronous import client_session
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from bson import json_util
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
@ -219,597 +189,3 @@ class AsyncSpecTestCreator:
self._create_tests()
else:
asyncio.run(self._create_tests())
class AsyncSpecRunner(AsyncIntegrationTest):
mongos_clients: List
knobs: client_knobs
listener: EventListener
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
async def asyncTearDown(self) -> None:
self.knobs.disable()
async def set_fail_point(self, command_args):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
await self.configure_fail_point(client, command_args)
async def targeted_fail_point(self, session, fail_point):
"""Run the targetedFailPoint test operation.
Enable the fail point on the session's pinned mongos.
"""
clients = {c.address: c for c in self.mongos_clients}
client = clients[session._pinned_address]
await self.configure_fail_point(client, fail_point)
self.addAsyncCleanup(self.set_fail_point, {"mode": "off"})
def assert_session_pinned(self, session):
"""Run the assertSessionPinned test operation.
Assert that the given session is pinned.
"""
self.assertIsNotNone(session._transaction.pinned_address)
def assert_session_unpinned(self, session):
"""Run the assertSessionUnpinned test operation.
Assert that the given session is not pinned.
"""
self.assertIsNone(session._pinned_address)
self.assertIsNone(session._transaction.pinned_address)
async def assert_collection_exists(self, database, collection):
"""Run the assertCollectionExists test operation."""
db = self.client[database]
self.assertIn(collection, await db.list_collection_names())
async def assert_collection_not_exists(self, database, collection):
"""Run the assertCollectionNotExists test operation."""
db = self.client[database]
self.assertNotIn(collection, await db.list_collection_names())
async def assert_index_exists(self, database, collection, index):
"""Run the assertIndexExists test operation."""
coll = self.client[database][collection]
self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()])
async def assert_index_not_exists(self, database, collection, index):
"""Run the assertIndexNotExists test operation."""
coll = self.client[database][collection]
self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()])
async def wait(self, ms):
"""Run the "wait" test operation."""
await asyncio.sleep(ms / 1000.0)
def assertErrorLabelsContain(self, exc, expected_labels):
labels = [l for l in expected_labels if exc.has_error_label(l)]
self.assertEqual(labels, expected_labels)
def assertErrorLabelsOmit(self, exc, omit_labels):
for label in omit_labels:
self.assertFalse(
exc.has_error_label(label), msg=f"error labels should not contain {label}"
)
async def kill_all_sessions(self):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
try:
await client.admin.command("killAllSessions", [])
except (OperationFailure, AutoReconnect):
# "operation was interrupted" by killing the command's
# own session.
# On 8.0+ killAllSessions sometimes returns a network error.
pass
def check_command_result(self, expected_result, result):
# Only compare the keys in the expected result.
filtered_result = {}
for key in expected_result:
try:
filtered_result[key] = result[key]
except KeyError:
pass
self.assertEqual(filtered_result, expected_result)
# TODO: factor the following function with test_crud.py.
def check_result(self, expected_result, result):
if isinstance(result, _WriteResult):
for res in expected_result:
prop = camel_to_snake(res)
# SPEC-869: Only BulkWriteResult has upserted_count.
if prop == "upserted_count" and not isinstance(result, BulkWriteResult):
if result.upserted_id is not None:
upserted_count = 1
else:
upserted_count = 0
self.assertEqual(upserted_count, expected_result[res], prop)
elif prop == "inserted_ids":
# BulkWriteResult does not have inserted_ids.
if isinstance(result, BulkWriteResult):
self.assertEqual(len(expected_result[res]), result.inserted_count)
else:
# InsertManyResult may be compared to [id1] from the
# crud spec or {"0": id1} from the retryable write spec.
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, prop)
elif prop == "upserted_ids":
# Convert indexes from strings to integers.
ids = expected_result[res]
expected_ids = {}
for str_index in ids:
expected_ids[int(str_index)] = ids[str_index]
self.assertEqual(expected_ids, result.upserted_ids, prop)
else:
self.assertEqual(getattr(result, prop), expected_result[res], prop)
return True
else:
def _helper(expected_result, result):
if isinstance(expected_result, abc.Mapping):
for i in expected_result.keys():
self.assertEqual(expected_result[i], result[i])
elif isinstance(expected_result, list):
for i, k in zip(expected_result, result):
_helper(i, k)
else:
self.assertEqual(expected_result, result)
_helper(expected_result, result)
return None
def get_object_name(self, op):
"""Allow subclasses to override handling of 'object'
Transaction spec says 'object' is required.
"""
return op["object"]
@staticmethod
def parse_options(opts):
return parse_spec_options(opts)
async def run_operation(self, sessions, collection, operation):
original_collection = collection
name = camel_to_snake(operation["name"])
if name == "run_command":
name = "command"
elif name == "download_by_name":
name = "open_download_stream_by_name"
elif name == "download":
name = "open_download_stream"
elif name == "map_reduce":
self.skipTest("PyMongo does not support mapReduce")
elif name == "count":
self.skipTest("PyMongo does not support count")
database = collection.database
collection = database.get_collection(collection.name)
if "collectionOptions" in operation:
collection = collection.with_options(
**self.parse_options(operation["collectionOptions"])
)
object_name = self.get_object_name(operation)
if object_name == "gridfsbucket":
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = AsyncGridFSBucket(database, bucket_name=collection.name)
else:
objects = {
"client": database.client,
"database": database,
"collection": collection,
"testRunner": self,
}
objects.update(sessions)
obj = objects[object_name]
# Combine arguments with options and handle special cases.
arguments = operation.get("arguments", {})
arguments.update(arguments.pop("options", {}))
self.parse_options(arguments)
cmd = getattr(obj, name)
with_txn_callback = functools.partial(
self.run_operations, sessions, original_collection, in_with_transaction=True
)
prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)
if name == "run_on_thread":
args = {"sessions": sessions, "collection": collection}
args.update(arguments)
arguments = args
if not _IS_SYNC and iscoroutinefunction(cmd):
result = await cmd(**dict(arguments))
else:
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addAsyncCleanup(result.close)
if name == "aggregate":
if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
# Read from the primary to ensure causal consistency.
out = collection.database.get_collection(
arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
)
return out.find()
if "download" in name:
result = Binary(result.read())
if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor):
return await result.to_list()
return result
def allowable_errors(self, op):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
async def _run_op(self, sessions, collection, op, in_with_transaction):
expected_result = op.get("result")
if expect_error(op):
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
await self.run_operation(sessions, collection, op.copy())
exc = context.exception
if expect_error_message(expected_result):
if isinstance(exc, BulkWriteError):
errmsg = str(exc.details).lower()
else:
errmsg = str(exc).lower()
self.assertIn(expected_result["errorContains"].lower(), errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
if expect_timeout_error(expected_result):
self.assertIsInstance(exc, PyMongoError)
if not exc.timeout:
# Re-raise the exception for better diagnostics.
raise exc
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = await self.run_operation(sessions, collection, op.copy())
if "result" in op:
if op["name"] == "runCommand":
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
async def run_operations(self, sessions, collection, ops, in_with_transaction=False):
for op in ops:
await self._run_op(sessions, collection, op, in_with_transaction)
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
events = listener.started_events
if not len(test["expectations"]):
return
# Give a nicer message when there are missing or extra events
cmds = decode_raw([event.command for event in events])
self.assertEqual(len(events), len(test["expectations"]), cmds)
for i, expectation in enumerate(test["expectations"]):
event_type = next(iter(expectation))
event = events[i]
# The tests substitute 42 for any number other than 0.
if event.command_name == "getMore" and event.command["getMore"]:
event.command["getMore"] = Int64(42)
elif event.command_name == "killCursors":
event.command["cursors"] = [Int64(42)]
elif event.command_name == "update":
# TODO: remove this once PYTHON-1744 is done.
# Add upsert and multi fields back into expectations.
updates = expectation[event_type]["command"]["updates"]
for update in updates:
update.setdefault("upsert", False)
update.setdefault("multi", False)
# Replace afterClusterTime: 42 with actual afterClusterTime.
expected_cmd = expectation[event_type]["command"]
expected_read_concern = expected_cmd.get("readConcern")
if expected_read_concern is not None:
time = expected_read_concern.get("afterClusterTime")
if time == 42:
actual_time = event.command.get("readConcern", {}).get("afterClusterTime")
if actual_time is not None:
expected_read_concern["afterClusterTime"] = actual_time
recovery_token = expected_cmd.get("recoveryToken")
if recovery_token == 42:
expected_cmd["recoveryToken"] = CompareType(dict)
# Replace lsid with a name like "session0" to match test.
if "lsid" in event.command:
for name, lsid in session_ids.items():
if event.command["lsid"] == lsid:
event.command["lsid"] = name
break
for attr, expected in expectation[event_type].items():
actual = getattr(event, attr)
expected = wrap_types(expected)
if isinstance(expected, dict):
for key, val in expected.items():
if val is None:
if key in actual:
self.fail(f"Unexpected key [{key}] in {actual!r}")
elif key not in actual:
self.fail(f"Expected key [{key}] in {actual!r}")
else:
self.assertEqual(
val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
)
else:
self.assertEqual(actual, expected)
def maybe_skip_scenario(self, test):
if test.get("skipReason"):
self.skipTest(test.get("skipReason"))
def get_scenario_db_name(self, scenario_def):
"""Allow subclasses to override a test's database name."""
return scenario_def["database_name"]
def get_scenario_coll_name(self, scenario_def):
"""Allow subclasses to override a test's collection name."""
return scenario_def["collection_name"]
def get_outcome_coll_name(self, outcome, collection):
"""Allow subclasses to override outcome collection."""
return collection.name
async def run_test_ops(self, sessions, collection, test):
"""Added to allow retryable writes spec to override a test's
operation.
"""
await self.run_operations(sessions, collection, test["operations"])
def parse_client_options(self, opts):
"""Allow encryption spec to override a clientOptions parsing."""
return opts
async def setup_scenario(self, scenario_def):
"""Allow specs to override a test's setup."""
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
documents = scenario_def["data"]
# Setup the collection with as few majority writes as possible.
db = async_client_context.client.get_database(db_name)
coll_exists = bool(await db.list_collection_names(filter={"name": coll_name}))
if coll_exists:
await db[coll_name].delete_many({})
# Only use majority wc only on the final write.
wc = WriteConcern(w="majority")
if documents:
db.get_collection(coll_name, write_concern=wc).insert_many(documents)
elif not coll_exists:
# Ensure collection exists.
await db.create_collection(coll_name, write_concern=wc)
async def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
await self.kill_all_sessions()
self.addAsyncCleanup(self.kill_all_sessions)
await self.setup_scenario(scenario_def)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
await c[database_name][collection_name].distinct("x")
# Configure the fail point before creating the client.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(fp)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
listener = OvertCommandListener()
pool_listener = CMAPListener()
server_listener = ServerAndTopologyEventListener()
# Create a new client, to avoid interference from pooled sessions.
client_options = self.parse_client_options(test["clientOptions"])
use_multi_mongos = test["useMultipleMongoses"]
host = None
if use_multi_mongos:
if async_client_context.load_balancer:
host = async_client_context.MULTI_MONGOS_LB_URI
elif async_client_context.is_mongos:
host = async_client_context.mongos_seeds()
client = await self.async_rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Create session0 and session1.
sessions = {}
session_ids = {}
for i in range(2):
# Don't attempt to create sessions if they are not supported by
# the running server version.
if not async_client_context.sessions_enabled:
break
session_name = "session%d" % i
opts = camel_to_snake_args(test["sessionOptions"][session_name])
if "default_transaction_options" in opts:
txn_opts = self.parse_options(opts["default_transaction_options"])
txn_opts = client_session.TransactionOptions(**txn_opts)
opts["default_transaction_options"] = txn_opts
s = client.start_session(**dict(opts))
sessions[session_name] = s
# Store lsid so we can access it after end_session, in check_events.
session_ids[session_name] = s.session_id
self.addAsyncCleanup(end_sessions, sessions)
collection = client[database_name][collection_name]
await self.run_test_ops(sessions, collection, test)
await end_sessions(sessions)
self.check_events(test, listener, session_ids)
# Disable fail points.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(
{"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
# Assert final state is expected.
outcome = test["outcome"]
expected_c = outcome.get("collection")
if expected_c is not None:
outcome_coll_name = self.get_outcome_coll_name(outcome, collection)
# Read from the primary with local read concern to ensure causal
# consistency.
outcome_coll = async_client_context.client[collection.database.name].get_collection(
outcome_coll_name,
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.
self.assertEqual(wrap_types(expected_c["data"]), actual_data)
def expect_any_error(op):
if isinstance(op, dict):
return op.get("error")
return False
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return isinstance(expected_result["errorContains"], str)
return False
def expect_error_code(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorCodeName"]
return False
def expect_error_labels_contain(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsContain"]
return False
def expect_error_labels_omit(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsOmit"]
return False
def expect_timeout_error(expected_result):
if isinstance(expected_result, dict):
return expected_result["isTimeoutError"]
return False
def expect_error(op):
expected_result = op.get("result")
return (
expect_any_error(op)
or expect_error_message(expected_result)
or expect_error_code(expected_result)
or expect_error_labels_contain(expected_result)
or expect_error_labels_omit(expected_result)
or expect_timeout_error(expected_result)
)
async def end_sessions(sessions):
for s in sessions.values():
# Aborts the transaction if it's open.
await s.end_session()
def decode_raw(val):
"""Decode RawBSONDocuments in the given container."""
if isinstance(val, (list, abc.Mapping)):
return decode(encode({"v": val}))["v"]
return val
TYPES = {
"binData": Binary,
"long": Int64,
"int": int,
"string": str,
"objectId": ObjectId,
"object": dict,
"array": list,
}
def wrap_types(val):
"""Support $$type assertion in command results."""
if isinstance(val, list):
return [wrap_types(v) for v in val]
if isinstance(val, abc.Mapping):
typ = val.get("$$type")
if typ:
if isinstance(typ, str):
types = TYPES[typ]
else:
types = tuple(TYPES[t] for t in typ)
return CompareType(types)
d = {}
for key in val:
d[key] = wrap_types(val[key])
return d
return val

View File

@ -42,6 +42,91 @@
}
],
"tests": [
{
"description": "disambiguatedPaths is not present when showExpandedEvents is false/unset",
"runOnRequirements": [
{
"minServerVersion": "6.1.0",
"maxServerVersion": "8.1.99",
"topologies": [
"replicaset",
"load-balanced",
"sharded"
],
"serverless": "forbid"
},
{
"minServerVersion": "8.2.1",
"topologies": [
"replicaset",
"load-balanced",
"sharded"
],
"serverless": "forbid"
}
],
"operations": [
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"_id": 1,
"a": {
"1": 1
}
}
}
},
{
"name": "createChangeStream",
"object": "collection0",
"arguments": {
"pipeline": []
},
"saveResultAsEntity": "changeStream0"
},
{
"name": "updateOne",
"object": "collection0",
"arguments": {
"filter": {
"_id": 1
},
"update": {
"$set": {
"a.1": 2
}
}
}
},
{
"name": "iterateUntilDocumentOrError",
"object": "changeStream0",
"expectResult": {
"operationType": "update",
"ns": {
"db": "database0",
"coll": "collection0"
},
"updateDescription": {
"updatedFields": {
"$$exists": true
},
"removedFields": {
"$$exists": true
},
"truncatedArrays": {
"$$exists": true
},
"disambiguatedPaths": {
"$$exists": false
}
}
}
}
]
},
{
"description": "disambiguatedPaths is present on updateDescription when an ambiguous path is present",
"operations": [

View File

@ -63,47 +63,6 @@
}
]
},
{
"description": "nsType is present when creating timeseries",
"operations": [
{
"name": "dropCollection",
"object": "database0",
"arguments": {
"collection": "foo"
}
},
{
"name": "createChangeStream",
"object": "database0",
"arguments": {
"pipeline": [],
"showExpandedEvents": true
},
"saveResultAsEntity": "changeStream0"
},
{
"name": "createCollection",
"object": "database0",
"arguments": {
"collection": "foo",
"timeseries": {
"timeField": "time",
"metaField": "meta",
"granularity": "minutes"
}
}
},
{
"name": "iterateUntilDocumentOrError",
"object": "changeStream0",
"expectResult": {
"operationType": "create",
"nsType": "timeseries"
}
}
]
},
{
"description": "nsType is present when creating views",
"operations": [

View File

@ -0,0 +1,111 @@
{
"description": "tests that connections are returned to the pool on retry attempts for overload errors",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"connectionCheckedOutEvent",
"connectionCheckedInEvent"
]
}
},
{
"client": {
"id": "fail_point_client",
"useMultipleMongoses": false
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "backpressure-connection-checkin"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "coll"
}
}
],
"tests": [
{
"description": "overload error retry attempts return connections to the pool",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "fail_point_client",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"object": "collection",
"arguments": {
"filter": {}
},
"expectError": {
"isError": true,
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": [
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
}
]
}
]
}
]
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,253 @@
{
"description": "getMore-retried-backpressure",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4"
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent",
"commandFailedEvent",
"commandSucceededEvent"
]
}
},
{
"client": {
"id": "failPointClient",
"useMultipleMongoses": false
}
},
{
"database": {
"id": "db",
"client": "client0",
"databaseName": "default"
}
},
{
"collection": {
"id": "coll",
"database": "db",
"collectionName": "default"
}
}
],
"initialData": [
{
"databaseName": "default",
"collectionName": "default",
"documents": [
{
"a": 1
},
{
"a": 2
},
{
"a": 3
}
]
}
],
"tests": [
{
"description": "getMores are retried",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"getMore"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"object": "coll",
"arguments": {
"batchSize": 2,
"filter": {},
"sort": {
"a": 1
}
},
"expectResult": [
{
"a": 1
},
{
"a": 2
},
{
"a": 3
}
]
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandSucceededEvent": {
"commandName": "getMore"
}
}
]
}
]
},
{
"description": "getMores are retried maxAttempts=2 times",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"getMore"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"arguments": {
"batchSize": 2,
"filter": {}
},
"object": "coll",
"expectError": {
"isError": true,
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "killCursors"
}
},
{
"commandSucceededEvent": {
"commandName": "killCursors"
}
}
]
}
]
}
]
}

View File

@ -4,6 +4,7 @@
"runOnRequirements": [
{
"minServerVersion": "8.2.0",
"maxServerVersion": "8.99.99",
"topologies": [
"replicaset",
"sharded",

View File

@ -4,6 +4,7 @@
"runOnRequirements": [
{
"minServerVersion": "8.2.0",
"maxServerVersion": "8.99.99",
"topologies": [
"replicaset",
"sharded",

View File

@ -4,6 +4,7 @@
"runOnRequirements": [
{
"minServerVersion": "8.2.0",
"maxServerVersion": "8.99.99",
"topologies": [
"replicaset",
"sharded",

View File

@ -126,7 +126,7 @@
],
"tests": [
{
"description": "Insert QE suffixPreview",
"description": "Insert QE substringPreview",
"operations": [
{
"name": "insertOne",

View File

@ -4,6 +4,7 @@
"runOnRequirements": [
{
"minServerVersion": "8.2.0",
"maxServerVersion": "8.99.99",
"topologies": [
"replicaset",
"sharded",

View File

@ -0,0 +1,485 @@
{
"description": "fle2v2-InsertFind-keyAltName",
"schemaVersion": "1.25",
"runOnRequirements": [
{
"minServerVersion": "7.0.0",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
],
"csfle": {
"minLibmongocryptVersion": "1.18.0"
}
}
],
"createEntities": [
{
"client": {
"id": "client0",
"autoEncryptOpts": {
"keyVaultNamespace": "keyvault.datakeys",
"kmsProviders": {
"local": {
"key": "Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
}
},
"encryptedFieldsMap": {
"default.default": {
"fields": [
{
"path": "encryptedIndexed",
"bsonType": "string",
"queries": {
"queryType": "equality",
"contention": {
"$numberLong": "0"
}
},
"keyAltName": "altname"
}
]
}
}
},
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "db",
"client": "client0",
"databaseName": "default"
}
},
{
"collection": {
"id": "coll",
"database": "db",
"collectionName": "default"
}
},
{
"client": {
"id": "client_unencrypted",
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "db_unencrypted",
"client": "client_unencrypted",
"databaseName": "default"
}
},
{
"collection": {
"id": "coll_unencrypted",
"database": "db_unencrypted",
"collectionName": "default"
}
}
],
"initialData": [
{
"databaseName": "default",
"collectionName": "default",
"documents": [],
"createOptions": {
"encryptedFields": {
"fields": [
{
"keyId": {
"$binary": {
"base64": "EjRWeBI0mHYSNBI0VniQEg==",
"subType": "04"
}
},
"path": "encryptedIndexed",
"bsonType": "string",
"queries": {
"queryType": "equality",
"contention": {
"$numberLong": "0"
}
}
}
]
}
}
},
{
"databaseName": "keyvault",
"collectionName": "datakeys",
"documents": [
{
"_id": {
"$binary": {
"base64": "EjRWeBI0mHYSNBI0VniQEg==",
"subType": "04"
}
},
"keyMaterial": {
"$binary": {
"base64": "sHe0kz57YW7v8g9VP9sf/+K1ex4JqKc5rf/URX3n3p8XdZ6+15uXPaSayC6adWbNxkFskuMCOifDoTT+rkqMtFkDclOy884RuGGtUysq3X7zkAWYTKi8QAfKkajvVbZl2y23UqgVasdQu3OVBQCrH/xY00nNAs/52e958nVjBuzQkSb1T8pKJAyjZsHJ60+FtnfafDZSTAIBJYn7UWBCwQ==",
"subType": "00"
}
},
"creationDate": {
"$date": {
"$numberLong": "1648914851981"
}
},
"updateDate": {
"$date": {
"$numberLong": "1648914851981"
}
},
"status": {
"$numberInt": "0"
},
"masterKey": {
"provider": "local"
},
"keyAltNames": [
"altname"
]
}
]
}
],
"tests": [
{
"description": "Insert and find FLE2 indexed field",
"operations": [
{
"name": "insertOne",
"arguments": {
"document": {
"_id": 1,
"encryptedIndexed": "123"
}
},
"object": "coll"
},
{
"name": "find",
"arguments": {
"filter": {
"encryptedIndexed": "123"
}
},
"object": "coll",
"expectResult": [
{
"_id": 1,
"encryptedIndexed": "123"
}
]
},
{
"name": "find",
"object": "coll_unencrypted",
"arguments": {
"filter": {}
},
"expectResult": [
{
"_id": 1,
"encryptedIndexed": {
"$$type": "binData"
},
"__safeContent__": [
{
"$binary": {
"base64": "31eCYlbQoVboc5zwC8IoyJVSkag9PxREka8dkmbXJeY=",
"subType": "00"
}
}
]
}
]
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"find": "datakeys",
"filter": {
"$or": [
{
"_id": {
"$in": []
}
},
{
"keyAltNames": {
"$in": [
"altname"
]
}
}
]
},
"$db": "keyvault",
"readConcern": {
"level": "majority"
}
},
"commandName": "find"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "default",
"documents": [
{
"_id": 1,
"encryptedIndexed": {
"$$type": "binData"
}
}
],
"ordered": true,
"encryptionInformation": {
"type": 1,
"schema": {
"default.default": {
"escCollection": "enxcol_.default.esc",
"ecocCollection": "enxcol_.default.ecoc",
"fields": [
{
"keyId": {
"$binary": {
"base64": "EjRWeBI0mHYSNBI0VniQEg==",
"subType": "04"
}
},
"path": "encryptedIndexed",
"bsonType": "string",
"queries": {
"queryType": "equality",
"contention": {
"$numberLong": "0"
}
}
}
]
}
}
}
},
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"command": {
"find": "default",
"filter": {
"encryptedIndexed": {
"$eq": {
"$binary": {
"base64": "DIkAAAAFZAAgAAAAAPGmZcUzdE/FPILvRSyAScGvZparGI2y9rJ/vSBxgCujBXMAIAAAAACi1RjmndKqgnXy7xb22RzUbnZl1sOZRXPOC0KcJkAxmQVsACAAAAAApJtKPW4+o9B7gAynNLL26jtlB4+hq5TXResijcYet8USY20AAAAAAAAAAAAA",
"subType": "06"
}
}
}
},
"encryptionInformation": {
"type": 1,
"schema": {
"default.default": {
"escCollection": "enxcol_.default.esc",
"ecocCollection": "enxcol_.default.ecoc",
"fields": [
{
"keyId": {
"$binary": {
"base64": "EjRWeBI0mHYSNBI0VniQEg==",
"subType": "04"
}
},
"path": "encryptedIndexed",
"bsonType": "string",
"queries": {
"queryType": "equality",
"contention": {
"$numberLong": "0"
}
}
}
]
}
}
}
},
"commandName": "find"
}
}
]
}
]
},
{
"description": "Create translates keyAltName",
"operations": [
{
"name": "dropCollection",
"object": "db",
"arguments": {
"collection": "default"
}
},
{
"name": "createCollection",
"object": "db",
"arguments": {
"collection": "default"
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"drop": "enxcol_.default.esc"
},
"commandName": "drop"
}
},
{
"commandStartedEvent": {
"command": {
"drop": "enxcol_.default.ecoc"
},
"commandName": "drop"
}
},
{
"commandStartedEvent": {
"command": {
"drop": "default"
},
"commandName": "drop"
}
},
{
"commandStartedEvent": {
"command": {
"create": "enxcol_.default.esc",
"clusteredIndex": {
"key": {
"_id": 1
},
"unique": true
}
},
"commandName": "create"
}
},
{
"commandStartedEvent": {
"command": {
"create": "enxcol_.default.ecoc",
"clusteredIndex": {
"key": {
"_id": 1
},
"unique": true
}
},
"commandName": "create"
}
},
{
"commandStartedEvent": {
"command": {
"find": "datakeys",
"filter": {
"$or": [
{
"_id": {
"$in": []
}
},
{
"keyAltNames": {
"$in": [
"altname"
]
}
}
]
},
"$db": "keyvault",
"readConcern": {
"level": "majority"
}
},
"commandName": "find"
}
},
{
"commandStartedEvent": {
"command": {
"create": "default",
"encryptedFields": {
"fields": [
{
"path": "encryptedIndexed",
"bsonType": "string",
"queries": {
"queryType": "equality",
"contention": {
"$numberLong": "0"
}
},
"keyId": {
"$binary": {
"base64": "EjRWeBI0mHYSNBI0VniQEg==",
"subType": "04"
}
}
}
]
}
},
"commandName": "create"
}
},
{
"commandStartedEvent": {
"command": {
"createIndexes": "default",
"indexes": [
{
"name": "__safeContent___1",
"key": {
"__safeContent__": 1
}
}
]
},
"commandName": "createIndexes"
}
}
]
}
]
}
]
}

View File

@ -97,14 +97,22 @@
"outcome": {
"servers": {
"a:27017": {
"type": "Unknown",
"topologyVersion": null,
"type": "RSPrimary",
"setName": "rs",
"topologyVersion": {
"processId": {
"$oid": "000000000000000000000001"
},
"counter": {
"$numberLong": "1"
}
},
"pool": {
"generation": 1
"generation": 0
}
}
},
"topologyType": "ReplicaSetNoPrimary",
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs"
}

169
test/test_azure_helpers.py Normal file
View File

@ -0,0 +1,169 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for _azure_helpers.py.
These tests mock urlopen to avoid requiring a live Azure IMDS endpoint.
Integration tests that exercise the real endpoint are gated by environment
variables in test_on_demand_csfle.py and test_auth_oidc.py.
"""
from __future__ import annotations
import json
import sys
import unittest
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
from pymongo._azure_helpers import _get_azure_response
@contextmanager
def _mock_urlopen(status: int, body: str):
"""Context manager that patches ``urllib.request.urlopen`` with a fake response."""
mock_response = MagicMock()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_response.status = status
mock_response.read.return_value = body.encode("utf8")
with patch("urllib.request.urlopen", return_value=mock_response) as mock_open:
yield mock_open
class TestGetAzureResponse(unittest.TestCase):
def _call(self, resource="https://example.com/", client_id=None, timeout=5):
return _get_azure_response(resource, client_id=client_id, timeout=timeout)
def test_success_without_client_id(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
result = self._call()
self.assertEqual(result["access_token"], "tok")
self.assertEqual(result["expires_in"], "3600")
# Verify client_id was NOT added to the URL
url = mock_open.call_args[0][0].full_url
self.assertNotIn("client_id", url)
def test_success_with_client_id(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
result = self._call(client_id="my-client-id")
self.assertEqual(result["access_token"], "tok")
url = mock_open.call_args[0][0].full_url
self.assertIn("client_id=my-client-id", url)
def test_url_contains_resource_and_api_version(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call(resource="https://test-resource.example.com")
url = mock_open.call_args[0][0].full_url
self.assertIn("api-version=2018-02-01", url)
self.assertIn("resource=https://test-resource.example.com", url)
def test_request_headers(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call()
request = mock_open.call_args[0][0]
self.assertEqual(request.get_header("Metadata"), "true")
self.assertEqual(request.get_header("Accept"), "application/json")
def test_urlopen_exception_raises_value_error(self):
with patch("urllib.request.urlopen", side_effect=OSError("connection refused")):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
def test_non_200_status_raises_value_error(self):
body = json.dumps({"error": "something went wrong"})
with _mock_urlopen(400, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
def test_non_json_body_raises_value_error(self):
with _mock_urlopen(200, "not-json"):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("Azure IMDS response must be in JSON format", str(ctx.exception))
def test_missing_access_token_raises_value_error(self):
body = json.dumps({"expires_in": "3600"})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("access_token", str(ctx.exception))
def test_missing_expires_in_raises_value_error(self):
body = json.dumps({"access_token": "tok"})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("expires_in", str(ctx.exception))
def test_empty_access_token_raises_value_error(self):
body = json.dumps({"access_token": "", "expires_in": "3600"})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("access_token", str(ctx.exception))
def test_empty_expires_in_raises_value_error(self):
body = json.dumps({"access_token": "tok", "expires_in": ""})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("expires_in", str(ctx.exception))
def test_timeout_passed_to_urlopen(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call(timeout=42)
_, kwargs = mock_open.call_args
self.assertEqual(kwargs["timeout"], 42)
def test_client_id_is_url_encoded(self):
"""Ensure special characters in client_id are percent-encoded."""
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call(client_id="id with spaces&special=chars")
url = mock_open.call_args[0][0].full_url
# '&' and '=' must be percent-encoded so they don't inject extra query params
self.assertIn("client_id=id%20with%20spaces%26special%3Dchars", url)
# The encoded client_id should not introduce a raw '&'
# Count params: api-version, resource, client_id — exactly 3
query_string = url.split("?", 1)[1]
self.assertEqual(query_string.count("&"), 2)
if __name__ == "__main__":
unittest.main()

View File

@ -1269,6 +1269,22 @@ class TestBSON(unittest.TestCase):
encode(doc)
self.assertEqual(cm.exception.document, doc)
def test_binary_length_accounts_for_header(self):
size = 20
binary_length = 12 # 5 more than the actual 7 bytes
payload = b""
payload += struct.pack("<i", size) # document size
payload += b"\x05" # type = Binary
payload += b"a\x00" # key "a"
payload += struct.pack("<I", binary_length) # Binary length (inflated)
payload += b"\x00" # subtype 0
payload += b"\x41" * 7 # value
payload += b"\x00" # EOO
with self.assertRaises(InvalidBSON):
decode(payload)
class TestCodecOptions(unittest.TestCase):
def test_document_class(self):

View File

@ -645,6 +645,38 @@ class ClientUnitTest(UnitTest):
with self.assertWarns(UserWarning):
self.simple_client(multi_host)
def test_max_adaptive_retries(self):
# Assert that max adaptive retries defaults to 2.
c = self.simple_client(connect=False)
self.assertEqual(c.options.max_adaptive_retries, 2)
# Assert that max adaptive retries can be configured through connection or client options.
c = self.simple_client(connect=False, max_adaptive_retries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(connect=False, maxAdaptiveRetries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False)
self.assertEqual(c.options.max_adaptive_retries, 10)
def test_enable_overload_retargeting(self):
# Assert that overload retargeting defaults to false.
c = self.simple_client(connect=False)
self.assertFalse(c.options.enable_overload_retargeting)
# Assert that overload retargeting can be enabled through connection or client options.
c = self.simple_client(connect=False, enable_overload_retargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(connect=False, enableOverloadRetargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(
host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False
)
self.assertTrue(c.options.enable_overload_retargeting)
class TestClient(IntegrationTest):
def test_multiple_uris(self):
@ -1007,7 +1039,7 @@ class TestClient(IntegrationTest):
db_names = self.client.list_database_names()
self.assertIn("pymongo_test", db_names)
self.assertIn("pymongo_test_mike", db_names)
self.assertEqual(db_names, cmd_names)
self.assertCountEqual(db_names, cmd_names)
def test_drop_database(self):
with self.assertRaises(TypeError):
@ -2634,11 +2666,11 @@ class TestClientPool(MockClientTest):
wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(c.address, ("c", 3))
# Assert that we create 1 pooled connection.
# Wait for the pooled connection to be registered
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.conns), 1)
wait_until(lambda: len(arbiter.pool.conns) == 1, "create 1 pooled connection")
# Arbiter pool is marked ready.
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)

View File

@ -0,0 +1,310 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Client Backpressure spec."""
from __future__ import annotations
import os
import pathlib
import sys
from time import perf_counter
from unittest.mock import patch
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
from test import (
IntegrationTest,
client_context,
unittest,
)
from test.unified_format import generate_test_classes
from test.utils_shared import EventListener, OvertCommandListener
from pymongo.errors import OperationFailure, PyMongoError
_IS_SYNC = True
# Mock a system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def get_mock_overload_error(times: int):
error = mock_overload_error.copy()
error["mode"] = {"times": times}
return error
class TestBackpressure(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
@client_context.require_failCommand_appName
def test_retry_overload_error_command(self):
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.command("find", "t")
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.command("find", "t")
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_find(self):
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.find_one()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.find_one()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_insert_one(self):
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.insert_one({"x": 1})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.insert_one({"x": 1})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_update_many(self):
# Even though update_many is not a retryable write operation, it will
# still be retried via the "RetryableError" error label.
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.update_many({}, {"$set": {"x": 2}})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.update_many({}, {"$set": {"x": 2}})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_getMore(self):
coll = self.db.t
coll.insert_many([{"x": 1} for _ in range(10)])
# Ensure command is retried on overload error.
fail_many = {
"configureFailPoint": "failCommand",
"mode": {"times": MAX_ADAPTIVE_RETRIES},
"data": {
"failCommands": ["getMore"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
cursor = coll.find(batch_size=2)
cursor.next()
with self.fail_point(fail_many):
cursor.to_list()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = fail_many.copy()
fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1}
cursor = coll.find(batch_size=2)
cursor.next()
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
cursor.to_list()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# Prose tests.
class TestClientBackpressure(IntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
@client_context.require_connection
def setUp(self) -> None:
super().setUp()
self.listener.reset()
self.app_name = self.__class__.__name__.lower()
self.client = self.rs_or_single_client(
event_listeners=[self.listener], appName=self.app_name
)
@patch("random.random")
@client_context.require_failCommand_appName
def test_01_operation_retry_uses_exponential_backoff(self, random_func):
# Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered.
# 1. let `client` be a `MongoClient`
client = self.client
# 2. let `collection` be a collection
collection = client.test.test
# 3. Now, run transactions without backoff:
# a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff.
random_func.return_value = 0
# b. Configure the following failPoint:
fail_point = dict(
mode="alwaysOn",
data=dict(
failCommands=["insert"],
errorCode=2,
errorLabels=["SystemOverloadedError", "RetryableError"],
appName=self.app_name,
),
)
with self.fail_point(fail_point):
# c. Execute the following command. Expect that the command errors. Measure the duration of the command execution.
start0 = perf_counter()
with self.assertRaises(OperationFailure):
collection.insert_one({"a": 1})
end0 = perf_counter()
# d. Configure the random number generator used for jitter to always return `1`.
random_func.return_value = 1
# e. Execute step c again.
start1 = perf_counter()
with self.assertRaises(OperationFailure):
collection.insert_one({"a": 1})
end1 = perf_counter()
# f. Compare the times between the two runs.
# The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two
# runs.
self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3)
@client_context.require_failCommand_appName
def test_03_overload_retries_limited(self):
# Drivers should test that overload errors are retried a maximum of two times.
# 1. Let `client` be a `MongoClient`.
client = self.client
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1.
self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1)
@client_context.require_failCommand_appName
def test_04_overload_retries_limited_configured(self):
# Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times.
max_retries = 1
# 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled.
client = self.single_client(maxAdaptiveRetries=max_retries, event_listeners=[self.listener])
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is max_retries + 1.
self.assertEqual(len(self.listener.started_events), max_retries + 1)
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -219,6 +219,19 @@ class TestClientMetadataProse(IntegrationTest):
# add same metadata again
self.check_metadata_added(client, "Framework", None, None)
def test_handshake_documents_include_backpressure(self):
# Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of
# connection establishment.
client = self.rs_or_single_client("mongodb://" + self.server.address_string)
# Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is
# established on all topologies. Note: MockupDB only supports standalone servers.
client.admin.command("ping")
# Assert that for every handshake document intercepted:
# the document has a field `backpressure` whose value is `true`.
self.assertEqual(self.handshake_req["backpressure"], True)
if __name__ == "__main__":
unittest.main()

View File

@ -257,7 +257,6 @@ class TestCollation(IntegrationTest):
self.assertEqual(
ja_collation.document["locale"], indexes["japanese_version"]["collation"]["locale"]
)
self.assertNotIn("collation", indexes["simple"])
self.db.test.drop_index("fieldname_1")
indexes = self.db.test.index_information()
self.assertIn("japanese_version", indexes)

183
test/test_daemon.py Normal file
View File

@ -0,0 +1,183 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the pymongo daemon module."""
from __future__ import annotations
import subprocess
import sys
import warnings
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
from test import unittest
import pymongo.daemon as daemon_module
from pymongo.daemon import _popen_wait, _silence_resource_warning, _spawn_daemon
class TestPopenWait(unittest.TestCase):
def test_returns_returncode_on_success(self):
mock_popen = MagicMock()
mock_popen.wait.return_value = 0
self.assertEqual(0, _popen_wait(mock_popen, timeout=5))
mock_popen.wait.assert_called_once_with(timeout=5)
def test_returns_none_on_timeout_expired(self):
mock_popen = MagicMock()
mock_popen.wait.side_effect = subprocess.TimeoutExpired(cmd="foo", timeout=5)
self.assertIsNone(_popen_wait(mock_popen, timeout=5))
def test_none_timeout_passes_through(self):
mock_popen = MagicMock()
mock_popen.wait.return_value = 1
self.assertEqual(1, _popen_wait(mock_popen, timeout=None))
mock_popen.wait.assert_called_once_with(timeout=None)
class TestSilenceResourceWarning(unittest.TestCase):
def test_sets_returncode_to_zero(self):
mock_popen = MagicMock()
mock_popen.returncode = None
_silence_resource_warning(mock_popen)
self.assertEqual(0, mock_popen.returncode)
def test_no_op_for_none(self):
# Should not raise when popen is None (mongocryptd spawn failed).
_silence_resource_warning(None)
@unittest.skipIf(sys.platform == "win32", "Unix only")
class TestSpawnUnix(unittest.TestCase):
def setUp(self):
from pymongo.daemon import _spawn
self._spawn = _spawn
def test_returns_popen_on_success(self):
mock_popen = MagicMock()
with patch("subprocess.Popen", return_value=mock_popen):
result = self._spawn(["somecommand"])
self.assertIs(mock_popen, result)
def test_filenotfound_warns_and_returns_none(self):
with patch("subprocess.Popen", side_effect=FileNotFoundError("not found")):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = self._spawn(["nonexistent_command"])
self.assertIsNone(result)
self.assertEqual(1, len(w))
self.assertIs(RuntimeWarning, w[0].category)
self.assertIn("nonexistent_command", str(w[0].message))
@unittest.skipIf(sys.platform == "win32", "Unix only")
class TestSpawnDaemonDoublePopen(unittest.TestCase):
def setUp(self):
from pymongo.daemon import _spawn_daemon_double_popen
self._spawn_daemon_double_popen = _spawn_daemon_double_popen
def test_spawns_this_file_as_intermediate(self):
mock_popen = MagicMock()
mock_popen.wait.return_value = 0
with patch("subprocess.Popen", return_value=mock_popen) as mock_cls:
self._spawn_daemon_double_popen(["somecommand", "--arg"])
spawner_args = mock_cls.call_args[0][0]
self.assertEqual(sys.executable, spawner_args[0])
self.assertIn("daemon.py", spawner_args[1])
self.assertIn("somecommand", spawner_args)
def test_waits_for_intermediate_process(self):
mock_popen = MagicMock()
with patch("subprocess.Popen", return_value=mock_popen):
self._spawn_daemon_double_popen(["somecommand"])
mock_popen.wait.assert_called_once_with(timeout=daemon_module._WAIT_TIMEOUT)
def test_continues_on_timeout(self):
# _popen_wait swallows TimeoutExpired — double Popen must not raise.
mock_popen = MagicMock()
mock_popen.wait.side_effect = subprocess.TimeoutExpired(cmd="foo", timeout=10)
with patch("subprocess.Popen", return_value=mock_popen):
self._spawn_daemon_double_popen(["somecommand"]) # must not raise
@unittest.skipIf(sys.platform == "win32", "Unix only")
class TestSpawnDaemonUnix(unittest.TestCase):
def test_uses_double_popen_when_executable_set(self):
with patch("pymongo.daemon._spawn_daemon_double_popen") as mock_double:
_spawn_daemon(["somecommand"])
mock_double.assert_called_once_with(["somecommand"])
def test_fallback_to_spawn_when_no_executable(self):
with patch("pymongo.daemon._spawn") as mock_spawn:
with patch.object(sys, "executable", ""):
_spawn_daemon(["somecommand"])
mock_spawn.assert_called_once_with(["somecommand"])
@unittest.skipUnless(sys.platform == "win32", "Windows only")
class TestSpawnDaemonWindows(unittest.TestCase):
def test_silences_resource_warning_on_success(self):
mock_popen = MagicMock()
with patch("subprocess.Popen", return_value=mock_popen):
_spawn_daemon(["somecommand"])
self.assertEqual(0, mock_popen.returncode)
def test_filenotfound_warns(self):
with patch("subprocess.Popen", side_effect=FileNotFoundError("not found")):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_spawn_daemon(["nonexistent_command"])
self.assertEqual(1, len(w))
self.assertIs(RuntimeWarning, w[0].category)
self.assertIn("nonexistent_command", str(w[0].message))
def test_uses_detached_process_flag(self):
# DETACHED_PROCESS must be passed so the child survives parent exit.
mock_popen = MagicMock()
with patch("subprocess.Popen", return_value=mock_popen) as mock_cls:
_spawn_daemon(["somecommand"])
kwargs = mock_cls.call_args[1]
self.assertEqual(daemon_module._DETACHED_PROCESS, kwargs["creationflags"])
def test_uses_devnull_for_stdio(self):
# stdin/stdout/stderr must be redirected to devnull to fully detach.
mock_popen = MagicMock()
with patch("subprocess.Popen", return_value=mock_popen) as mock_cls:
_spawn_daemon(["somecommand"])
kwargs = mock_cls.call_args[1]
self.assertIsNotNone(kwargs.get("stdin"))
self.assertIsNotNone(kwargs.get("stdout"))
self.assertIsNotNone(kwargs.get("stderr"))
def test_detached_process_constant_value(self):
# Value must match the Windows DETACHED_PROCESS process creation flag.
self.assertEqual(0x00000008, daemon_module._DETACHED_PROCESS)
@unittest.skipIf(sys.platform == "win32", "Unix only")
class TestMainBlock(unittest.TestCase):
def test_exits_with_zero(self):
# Run daemon.py as a script with a no-op subprocess; verify it exits cleanly.
result = subprocess.run(
[sys.executable, "-m", "pymongo.daemon", sys.executable, "-c", "pass"],
timeout=15,
)
self.assertEqual(0, result.returncode)
if __name__ == "__main__":
unittest.main()

View File

@ -25,7 +25,9 @@ from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.helpers import ConcurrentRunner
from test.utils import flaky
from test.utils_shared import delay
from pymongo.errors import ConnectionFailure
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import Connection
@ -67,7 +69,12 @@ from pymongo.errors import (
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionCheckOutFailedEvent,
PoolClearedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
)
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
@ -131,6 +138,9 @@ def got_app_error(topology, app_error):
raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes":
# The pool would have added the SystemOverloadedError in this case.
if isinstance(e, AutoReconnect):
e._add_error_label("SystemOverloadedError")
completed_handshake = False
elif when == "afterHandshakeCompletes":
completed_handshake = True
@ -437,6 +447,57 @@ class TestPoolManagement(IntegrationTest):
Connection.close_conn = original_close
class TestPoolBackpressure(IntegrationTest):
@client_context.require_version_min(7, 0, 0)
def test_connection_pool_is_not_cleared(self):
listener = CMAPListener()
# Create a client that listens to CMAP events, with maxConnecting=100.
client = self.rs_or_single_client(maxConnecting=100, event_listeners=[listener])
# Enable the ingress rate limiter.
client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True
)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1)
# Disable the ingress rate limiter on teardown.
# Sleep for 1 second before disabling to avoid the rate limiter.
def teardown():
time.sleep(1)
client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False
)
self.addCleanup(teardown)
# Make sure the collection has at least one document.
client.test.test.delete_many({})
client.test.test.insert_one({})
# Run a slow operation to tie up the connection.
def target():
try:
client.test.test.find_one({"$where": delay(0.1)})
except ConnectionFailure:
pass
# Run 100 parallel operations that contend for connections.
tasks = []
for _ in range(100):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
t.start()
for t in tasks:
t.join()
# Verify there were at least 10 connection checkout failed event but no pool cleared events.
self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
class TestServerMonitoringMode(IntegrationTest):
@client_context.require_no_load_balancer
def setUp(self):

View File

@ -872,8 +872,6 @@ class TestViews(EncryptionIntegrationTest):
class TestCorpus(EncryptionIntegrationTest):
# PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions.
@client_context.require_version_max(6, 99)
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
def setUp(self):
super().setUp()
@ -1050,8 +1048,6 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
client_encrypted: MongoClient
listener: OvertCommandListener
# PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions.
@client_context.require_version_max(6, 99)
def setUp(self):
super().setUp()
db = client_context.client.db
@ -3308,6 +3304,7 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest):
class TestExplicitTextEncryptionProse(EncryptionIntegrationTest):
@client_context.require_no_standalone
@client_context.require_version_min(8, 2, -1)
@client_context.require_version_max(8, 99, 99)
@client_context.require_libmongocrypt_min(1, 15, 1)
@client_context.require_pymongocrypt_min(1, 16, 0)
def setUp(self):

374
test/test_event_loggers.py Normal file
View File

@ -0,0 +1,374 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pymongo.event_loggers."""
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
from test import unittest
from pymongo.event_loggers import (
CommandLogger,
ConnectionPoolLogger,
HeartbeatLogger,
ServerLogger,
TopologyLogger,
)
class TestCommandLogger(unittest.TestCase):
def setUp(self):
self.logger = CommandLogger()
def test_started_logs_info(self):
event = MagicMock()
event.command_name = "find"
event.request_id = 42
event.connection_id = ("localhost", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.started(event)
log = logs.records[0].getMessage()
self.assertIn("find", log)
self.assertIn("42", log)
self.assertIn("started", log)
def test_succeeded_logs_info(self):
event = MagicMock()
event.command_name = "insert"
event.request_id = 7
event.connection_id = ("localhost", 27017)
event.duration_micros = 500
with self.assertLogs(level="INFO") as logs:
self.logger.succeeded(event)
log = logs.records[0].getMessage()
self.assertIn("insert", log)
self.assertIn("7", log)
self.assertIn("500", log)
self.assertIn("microseconds", log)
self.assertIn("succeeded", log)
def test_failed_logs_info(self):
event = MagicMock()
event.command_name = "delete"
event.request_id = 3
event.connection_id = ("localhost", 27017)
event.duration_micros = 300
with self.assertLogs(level="INFO") as logs:
self.logger.failed(event)
log = logs.records[0].getMessage()
self.assertIn("delete", log)
self.assertIn("3", log)
self.assertIn("300", log)
self.assertIn("microseconds", log)
self.assertIn("failed", log)
class TestServerLogger(unittest.TestCase):
def setUp(self):
self.logger = ServerLogger()
def test_opened_logs_info(self):
event = MagicMock()
event.server_address = ("host1", 27017)
event.topology_id = "topology-abc"
with self.assertLogs(level="INFO") as logs:
self.logger.opened(event)
log = logs.records[0].getMessage()
self.assertIn("host1", log)
self.assertIn("topology-abc", log)
def test_closed_logs_warning(self):
event = MagicMock()
event.server_address = ("host1", 27017)
event.topology_id = "topology-abc"
with self.assertLogs(level="WARNING") as logs:
self.logger.closed(event)
log = logs.records[0].getMessage()
self.assertIn("host1", log)
self.assertIn("topology-abc", log)
def test_description_changed_logs_when_type_changes(self):
event = MagicMock()
event.server_address = ("host1", 27017)
event.previous_description.server_type = 1
event.previous_description.server_type_name = "Unknown"
event.new_description.server_type = 2
event.new_description.server_type_name = "Standalone"
with self.assertLogs(level="INFO") as logs:
self.logger.description_changed(event)
log = logs.records[0].getMessage()
self.assertIn("Unknown", log)
self.assertIn("Standalone", log)
def test_description_changed_no_log_when_type_same(self):
event = MagicMock()
event.previous_description.server_type = 2
event.new_description.server_type = 2
with patch("logging.info") as mock_info:
self.logger.description_changed(event)
mock_info.assert_not_called()
class TestHeartbeatLogger(unittest.TestCase):
def setUp(self):
self.logger = HeartbeatLogger()
def test_started_logs_info(self):
event = MagicMock()
event.connection_id = ("mongo.host", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.started(event)
log = logs.records[0].getMessage()
self.assertIn("mongo.host", log)
def test_succeeded_logs_info(self):
event = MagicMock()
event.connection_id = ("mongo.host", 27017)
event.reply.document = {"ok": 1, "maxWireVersion": 17}
with self.assertLogs(level="INFO") as logs:
self.logger.succeeded(event)
log = logs.records[0].getMessage()
self.assertIn("mongo.host", log)
self.assertIn("succeeded", log)
self.assertIn("maxWireVersion", log)
def test_failed_logs_warning(self):
event = MagicMock()
event.connection_id = ("mongo.host", 27017)
event.reply = TimeoutError("timed out")
with self.assertLogs(level="WARNING") as logs:
self.logger.failed(event)
log = logs.records[0].getMessage()
self.assertIn("mongo.host", log)
self.assertIn("failed", log)
self.assertIn("timed out", log)
class TestTopologyLogger(unittest.TestCase):
def setUp(self):
self.logger = TopologyLogger()
def test_opened_logs_info(self):
event = MagicMock()
event.topology_id = "topo-1"
with self.assertLogs(level="INFO") as logs:
self.logger.opened(event)
log = logs.records[0].getMessage()
self.assertIn("topo-1", log)
self.assertIn("opened", log)
def test_closed_logs_info(self):
event = MagicMock()
event.topology_id = "topo-1"
with self.assertLogs(level="INFO") as logs:
self.logger.closed(event)
log = logs.records[0].getMessage()
self.assertIn("topo-1", log)
self.assertIn("closed", log)
def test_description_changed_always_logs_update(self):
event = MagicMock()
event.topology_id = "topo-1"
event.previous_description.topology_type = 1
event.new_description.topology_type = 1
event.new_description.has_writable_server.return_value = True
event.new_description.has_readable_server.return_value = True
with self.assertLogs(level="INFO") as logs:
self.logger.description_changed(event)
messages = [r.getMessage() for r in logs.records]
self.assertTrue(any("updated" in m for m in messages))
self.assertTrue(any("topo-1" in m for m in messages))
def test_description_changed_logs_type_change(self):
event = MagicMock()
event.topology_id = "topo-2"
event.previous_description.topology_type = 0
event.previous_description.topology_type_name = "Unknown"
event.new_description.topology_type = 1
event.new_description.topology_type_name = "Single"
event.new_description.has_writable_server.return_value = True
event.new_description.has_readable_server.return_value = True
with self.assertLogs(level="INFO") as logs:
self.logger.description_changed(event)
messages = [r.getMessage() for r in logs.records]
self.assertTrue(any("Unknown" in m and "Single" in m for m in messages))
def test_description_changed_no_type_change_log_when_same(self):
event = MagicMock()
event.topology_id = "topo-1"
event.previous_description.topology_type = 1
event.new_description.topology_type = 1
event.new_description.has_writable_server.return_value = True
event.new_description.has_readable_server.return_value = True
with self.assertLogs(level="INFO") as logs:
self.logger.description_changed(event)
messages = [r.getMessage() for r in logs.records]
self.assertFalse(any("changed type" in m for m in messages))
def test_description_changed_warns_no_writable_server(self):
event = MagicMock()
event.previous_description.topology_type = 1
event.new_description.topology_type = 1
event.new_description.has_writable_server.return_value = False
event.new_description.has_readable_server.return_value = True
with self.assertLogs(level="WARNING") as logs:
self.logger.description_changed(event)
messages = [r.getMessage() for r in logs.records]
self.assertTrue(any("writable" in m for m in messages))
def test_description_changed_warns_no_readable_server(self):
event = MagicMock()
event.previous_description.topology_type = 1
event.new_description.topology_type = 1
event.new_description.has_writable_server.return_value = True
event.new_description.has_readable_server.return_value = False
with self.assertLogs(level="WARNING") as logs:
self.logger.description_changed(event)
messages = [r.getMessage() for r in logs.records]
self.assertTrue(any("readable" in m for m in messages))
def test_description_changed_warns_both_unavailable(self):
event = MagicMock()
event.previous_description.topology_type = 1
event.new_description.topology_type = 1
event.new_description.has_writable_server.return_value = False
event.new_description.has_readable_server.return_value = False
with self.assertLogs(level="WARNING") as logs:
self.logger.description_changed(event)
warning_messages = [r.getMessage() for r in logs.records if r.levelname == "WARNING"]
self.assertEqual(len(warning_messages), 2)
class TestConnectionPoolLogger(unittest.TestCase):
def setUp(self):
self.logger = ConnectionPoolLogger()
def test_pool_created(self):
event = MagicMock()
event.address = ("localhost", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.pool_created(event)
log = logs.records[0].getMessage()
self.assertIn("pool created", log)
self.assertIn("localhost", log)
def test_pool_ready(self):
event = MagicMock()
event.address = ("localhost", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.pool_ready(event)
log = logs.records[0].getMessage()
self.assertIn("pool ready", log)
self.assertIn("localhost", log)
def test_pool_cleared(self):
event = MagicMock()
event.address = ("localhost", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.pool_cleared(event)
log = logs.records[0].getMessage()
self.assertIn("pool cleared", log)
self.assertIn("localhost", log)
def test_pool_closed(self):
event = MagicMock()
event.address = ("localhost", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.pool_closed(event)
log = logs.records[0].getMessage()
self.assertIn("pool closed", log)
self.assertIn("localhost", log)
def test_connection_created(self):
event = MagicMock()
event.address = ("localhost", 27017)
event.connection_id = 5
with self.assertLogs(level="INFO") as logs:
self.logger.connection_created(event)
log = logs.records[0].getMessage()
self.assertIn("connection created", log)
self.assertIn("5", log)
self.assertIn("localhost", log)
def test_connection_ready(self):
event = MagicMock()
event.address = ("localhost", 27017)
event.connection_id = 5
with self.assertLogs(level="INFO") as logs:
self.logger.connection_ready(event)
log = logs.records[0].getMessage()
self.assertIn("connection setup succeeded", log)
self.assertIn("5", log)
def test_connection_closed(self):
event = MagicMock()
event.address = ("localhost", 27017)
event.connection_id = 5
event.reason = "stale"
with self.assertLogs(level="INFO") as logs:
self.logger.connection_closed(event)
log = logs.records[0].getMessage()
self.assertIn("connection closed", log)
self.assertIn("5", log)
self.assertIn("stale", log)
def test_connection_check_out_started(self):
event = MagicMock()
event.address = ("localhost", 27017)
with self.assertLogs(level="INFO") as logs:
self.logger.connection_check_out_started(event)
log = logs.records[0].getMessage()
self.assertIn("check out started", log)
self.assertIn("localhost", log)
def test_connection_check_out_failed(self):
event = MagicMock()
event.address = ("localhost", 27017)
event.reason = "timeout"
with self.assertLogs(level="INFO") as logs:
self.logger.connection_check_out_failed(event)
log = logs.records[0].getMessage()
self.assertIn("check out failed", log)
self.assertIn("timeout", log)
self.assertIn("localhost", log)
def test_connection_checked_out(self):
event = MagicMock()
event.address = ("localhost", 27017)
event.connection_id = 3
with self.assertLogs(level="INFO") as logs:
self.logger.connection_checked_out(event)
log = logs.records[0].getMessage()
self.assertIn("checked out", log)
self.assertIn("3", log)
self.assertIn("localhost", log)
def test_connection_checked_in(self):
event = MagicMock()
event.address = ("localhost", 27017)
event.connection_id = 3
with self.assertLogs(level="INFO") as logs:
self.logger.connection_checked_in(event)
log = logs.records[0].getMessage()
self.assertIn("checked into", log)
self.assertIn("3", log)
self.assertIn("localhost", log)
if __name__ == "__main__":
unittest.main()

116
test/test_gcp_helpers.py Normal file
View File

@ -0,0 +1,116 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for pymongo/_gcp_helpers.py."""
from __future__ import annotations
import sys
import unittest
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
from pymongo._gcp_helpers import _get_gcp_response
@contextmanager
def _mock_urlopen(status: int, body: str):
"""Context manager that patches ``urllib.request.urlopen`` with a fake response."""
mock_response = MagicMock()
mock_response.__enter__ = MagicMock(return_value=mock_response)
mock_response.__exit__ = MagicMock(return_value=False)
mock_response.status = status
mock_response.read.return_value = body.encode("utf8")
with patch("urllib.request.urlopen", return_value=mock_response) as mock_open:
yield mock_open
class TestGetGcpResponse(unittest.TestCase):
"""Tests for :func:`pymongo._gcp_helpers._get_gcp_response`."""
def test_successful_response_returns_access_token(self):
"""A 200 response yields ``{"access_token": <body>}``."""
token = "ya29.some-gcp-token"
with _mock_urlopen(200, token):
result = _get_gcp_response("https://example.com")
self.assertEqual(result, {"access_token": token})
def test_non_200_status_raises_value_error(self):
"""A non-200 HTTP status raises :class:`ValueError`."""
for status in (400, 401, 403, 500, 503):
with self.subTest(status=status):
with _mock_urlopen(status, "error"):
with self.assertRaises(ValueError) as ctx:
_get_gcp_response("https://example.com")
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
def test_urlopen_exception_raises_value_error(self):
"""An exception from ``urlopen`` is wrapped in :class:`ValueError`."""
with patch("urllib.request.urlopen", side_effect=OSError("connection refused")):
with self.assertRaises(ValueError) as ctx:
_get_gcp_response("https://example.com")
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
self.assertIn("connection refused", str(ctx.exception))
def test_url_contains_resource_as_audience(self):
"""The ``resource`` argument is appended as ``?audience=`` in the URL."""
resource = "https://my-service.example.com"
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response(resource)
request_obj = mock_open.call_args[0][0]
self.assertIn(f"?audience={resource}", request_obj.full_url)
def test_request_has_metadata_flavor_google_header(self):
"""The request must include the ``Metadata-Flavor: Google`` header."""
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response("https://example.com")
request_obj = mock_open.call_args[0][0]
self.assertEqual(request_obj.get_header("Metadata-flavor"), "Google")
def test_default_timeout_is_five_seconds(self):
"""Without an explicit timeout, ``urlopen`` is called with ``timeout=5``."""
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response("https://example.com")
_, kwargs = mock_open.call_args
self.assertEqual(kwargs.get("timeout"), 5)
def test_custom_timeout_is_forwarded(self):
"""An explicit ``timeout`` value is passed through to ``urlopen``."""
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response("https://example.com", timeout=30)
_, kwargs = mock_open.call_args
self.assertEqual(kwargs.get("timeout"), 30)
def test_urlopen_exception_does_not_chain_original(self):
"""The raised ``ValueError`` suppresses the original exception (``from None``)."""
with patch("urllib.request.urlopen", side_effect=RuntimeError("network error")):
with self.assertRaises(ValueError) as ctx:
_get_gcp_response("https://example.com")
# ``raise ... from None`` sets __cause__ to None and __suppress_context__ to True.
self.assertIs(ctx.exception.__cause__, None)
self.assertIs(ctx.exception.__suppress_context__, True)
if __name__ == "__main__":
unittest.main()

View File

@ -511,6 +511,39 @@ class TestPooling(_TestPoolingBase):
str(error.exception),
)
@client_context.require_failCommand_appName
def test_pool_backpressure_preserves_existing_connections(self):
client = self.rs_or_single_client()
coll = client.pymongo_test.t
pool = get_pool(client)
coll.insert_many([{"x": 1} for _ in range(10)])
t = SocketGetter(self.c, pool)
t.start()
while t.state != "connection":
time.sleep(0.1)
assert not t.sock.conn_closed()
# Mock a session establishment overload.
mock_connection_fail = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"closeConnection": True,
},
}
with self.fail_point(mock_connection_fail):
coll.find_one({})
# Make sure the existing socket was not affected.
assert not t.sock.conn_closed()
# Cleanup
t.release_conn()
t.join()
pool.close()
class TestPoolMaxSize(_TestPoolingBase):
def test_max_pool_size(self):

View File

@ -0,0 +1,271 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for pyopenssl_context.py.
These tests require PyOpenSSL (install via: pip install pymongo[ocsp]).
Tests are automatically skipped when PyOpenSSL is not available.
"""
from __future__ import annotations
import ssl
import sys
from unittest.mock import patch
sys.path[0:0] = [""]
from test import unittest
try:
from pymongo import pyopenssl_context as _ctx_module
from pymongo.pyopenssl_context import (
PROTOCOL_SSLv23,
SSLContext,
_is_ip_address,
_ragged_eof,
)
_HAVE_PYOPENSSL = True
except ImportError:
_HAVE_PYOPENSSL = False
# ---------------------------------------------------------------------------
# Pure functions (no SSL context required)
# ---------------------------------------------------------------------------
class TestIsIpAddress(unittest.TestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_ipv4(self):
self.assertTrue(_is_ip_address("192.168.1.1"))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_ipv6(self):
self.assertTrue(_is_ip_address("::1"))
self.assertTrue(_is_ip_address("2001:db8::1"))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_hostname_is_not_ip(self):
self.assertFalse(_is_ip_address("example.com"))
self.assertFalse(_is_ip_address("localhost"))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_invalid_string_returns_false(self):
self.assertFalse(_is_ip_address("not-an-ip"))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_unicode_error_returns_false(self):
# UnicodeError path: some inputs that can't be decoded.
# ip_address raises UnicodeError for byte strings with non-ASCII.
self.assertFalse(_is_ip_address(b"\xff\xfe"))
class TestRaggedEof(unittest.TestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_matching_args_returns_true(self):
from OpenSSL.SSL import SysCallError
exc = SysCallError(-1, "Unexpected EOF")
self.assertTrue(_ragged_eof(exc))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_non_matching_args_returns_false(self):
from OpenSSL.SSL import SysCallError
exc = SysCallError(0, "something else")
self.assertFalse(_ragged_eof(exc))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_wrong_code_returns_false(self):
from OpenSSL.SSL import SysCallError
exc = SysCallError(5, "Unexpected EOF")
self.assertFalse(_ragged_eof(exc))
# ---------------------------------------------------------------------------
# SSLContext — construction and properties
# ---------------------------------------------------------------------------
class TestSSLContextConstruction(unittest.TestCase):
def _make(self):
return SSLContext(PROTOCOL_SSLv23)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_protocol_property(self):
ctx = self._make()
self.assertEqual(ctx.protocol, PROTOCOL_SSLv23)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_default_check_hostname(self):
ctx = self._make()
self.assertTrue(ctx.check_hostname)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_set_check_hostname_false(self):
ctx = self._make()
ctx.check_hostname = False
self.assertFalse(ctx.check_hostname)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_set_check_hostname_invalid_raises(self):
ctx = self._make()
with self.assertRaises(TypeError):
ctx.check_hostname = "yes"
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_default_check_ocsp_endpoint(self):
ctx = self._make()
self.assertTrue(ctx.check_ocsp_endpoint)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_set_check_ocsp_endpoint_false(self):
ctx = self._make()
ctx.check_ocsp_endpoint = False
self.assertFalse(ctx.check_ocsp_endpoint)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_verify_mode_roundtrip(self):
ctx = self._make()
ctx.verify_mode = ssl.CERT_REQUIRED
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_verify_mode_cert_none(self):
ctx = self._make()
ctx.verify_mode = ssl.CERT_NONE
self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_options_setter_and_getter(self):
ctx = self._make()
from pymongo.pyopenssl_context import OP_NO_SSLv3
ctx.options = OP_NO_SSLv3
self.assertTrue(ctx.options & OP_NO_SSLv3)
# ---------------------------------------------------------------------------
# SSLContext._load_certifi
# ---------------------------------------------------------------------------
class TestLoadCertifi(unittest.TestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_raises_when_certifi_unavailable(self):
from pymongo.errors import ConfigurationError
ctx = SSLContext(PROTOCOL_SSLv23)
with patch.object(_ctx_module, "_HAVE_CERTIFI", False):
with self.assertRaises(ConfigurationError) as exc_ctx:
ctx._load_certifi()
self.assertIn("certifi", str(exc_ctx.exception))
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_loads_when_certifi_available(self):
if not _ctx_module._HAVE_CERTIFI:
self.skipTest("certifi not installed")
ctx = SSLContext(PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_NONE
# Should not raise.
ctx._load_certifi()
# ---------------------------------------------------------------------------
# SSLContext.load_default_certs — platform branching
# ---------------------------------------------------------------------------
class TestLoadDefaultCerts(unittest.TestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_darwin_calls_load_certifi(self):
with patch.object(_ctx_module._sys, "platform", "darwin"):
with patch.object(SSLContext, "_load_certifi") as mock_certifi:
with patch("OpenSSL.SSL.Context.set_default_verify_paths"):
ctx = SSLContext(PROTOCOL_SSLv23)
ctx.load_default_certs()
mock_certifi.assert_called()
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_win32_calls_load_wincerts(self):
with patch.object(_ctx_module._sys, "platform", "win32"):
with patch.object(SSLContext, "_load_wincerts") as mock_wincerts:
with patch("OpenSSL.SSL.Context.set_default_verify_paths"):
ctx = SSLContext(PROTOCOL_SSLv23)
ctx.load_default_certs()
calls = [call.args[0] for call in mock_wincerts.call_args_list]
self.assertIn("CA", calls)
self.assertIn("ROOT", calls)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_win32_falls_back_to_certifi_on_exception(self):
with patch.object(_ctx_module._sys, "platform", "win32"):
with patch.object(SSLContext, "_load_wincerts", side_effect=Exception("no certs")):
with patch.object(SSLContext, "_load_certifi") as mock_certifi:
with patch("OpenSSL.SSL.Context.set_default_verify_paths"):
ctx = SSLContext(PROTOCOL_SSLv23)
ctx.load_default_certs()
mock_certifi.assert_called()
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_linux_no_certifi_call(self):
with patch.object(_ctx_module._sys, "platform", "linux"):
with patch.object(SSLContext, "_load_certifi") as mock_certifi:
with patch("OpenSSL.SSL.Context.set_default_verify_paths"):
ctx = SSLContext(PROTOCOL_SSLv23)
ctx.load_default_certs()
mock_certifi.assert_not_called()
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_calls_set_default_verify_paths(self):
with patch.object(_ctx_module._sys, "platform", "linux"):
ctx = SSLContext(PROTOCOL_SSLv23)
with patch.object(ctx._ctx, "set_default_verify_paths") as mock_sdvp:
ctx.load_default_certs()
mock_sdvp.assert_called_once()
# ---------------------------------------------------------------------------
# SSLContext.set_default_verify_paths
# ---------------------------------------------------------------------------
class TestSetDefaultVerifyPaths(unittest.TestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_delegates_to_ctx(self):
ctx = SSLContext(PROTOCOL_SSLv23)
with patch.object(ctx._ctx, "set_default_verify_paths") as mock_sdvp:
ctx.set_default_verify_paths()
mock_sdvp.assert_called_once()
# ---------------------------------------------------------------------------
# SSLContext.load_verify_locations
# ---------------------------------------------------------------------------
class TestLoadVerifyLocations(unittest.TestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_delegates_to_ctx(self):
ctx = SSLContext(PROTOCOL_SSLv23)
with patch.object(ctx._ctx, "load_verify_locations") as mock_lvl:
ctx.load_verify_locations(cafile="/tmp/ca.pem")
mock_lvl.assert_called_once_with("/tmp/ca.pem", None)
if __name__ == "__main__":
unittest.main()

View File

@ -19,9 +19,12 @@ import os
import pprint
import sys
import threading
from test.utils import set_fail_point
from test.utils import ensure_all_connected, set_fail_point
from unittest import mock
from pymongo.errors import OperationFailure
from pymongo import MongoClient
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import OperationFailure, PyMongoError
sys.path[0:0] = [""]
@ -38,6 +41,7 @@ from test.utils_shared import (
)
from pymongo.monitoring import (
CommandFailedEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutFailedReason,
@ -145,6 +149,19 @@ class TestPoolPausedError(IntegrationTest):
class TestRetryableReads(IntegrationTest):
def setUp(self) -> None:
super().setUp()
self.setup_client = MongoClient(**client_context.client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand", **command_args}
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
@client_context.require_multiple_mongoses
@client_context.require_failCommand_fail_point
def test_retryable_reads_are_retried_on_a_different_mongos_when_one_is_available(self):
@ -263,16 +280,22 @@ class TestRetryableReads(IntegrationTest):
@client_context.require_secondaries_count(1)
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0)
def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available(
def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled.
client = self.rs_or_single_client(
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred"
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
enableOverloadRetargeting=True,
)
# Ensure the client has discovered all nodes.
ensure_all_connected(client)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
@ -312,6 +335,9 @@ class TestRetryableReads(IntegrationTest):
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred"
)
# Ensure the client has discovered all nodes.
ensure_all_connected(client)
# 2. Configure a fail point with the RetryableError error label.
command_args = {
"configureFailPoint": "failCommand",
@ -337,6 +363,161 @@ class TestRetryableReads(IntegrationTest):
# 6. Assert that both events occurred the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@client_context.require_replica_set
@client_context.require_secondaries_count(1)
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0)
def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
client = self.rs_or_single_client(
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
)
# Ensure the client has discovered all nodes.
ensure_all_connected(client)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 6,
},
}
set_fail_point(client, command_args)
# 3. Reset the command event monitor to clear the fail point command from its stored events.
listener.reset()
# 4. Execute a `find` command with `client`.
client.t.t.find_one({})
# 5. Assert that one failed command event and one successful command event occurred.
self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)
# 6. Assert that both events occurred on the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0) # type:ignore[untyped-decorator]
def test_overload_then_nonoverload_retries_increased_reads(self) -> None:
# Create a client.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with only the `RetryableError` error label.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 91,
"errorLabels": ["RetryableError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(event_listeners=[listener])
client.test.test.insert_one({})
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
with self.assertRaises(PyMongoError):
client.test.test.find_one()
started_finds = [e for e in listener.started_events if e.command_name == "find"]
self.assertEqual(len(started_finds), MAX_ADAPTIVE_RETRIES + 1)
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0) # type:ignore[untyped-decorator]
def test_backoff_is_not_applied_for_non_overload_errors(self):
if _IS_SYNC:
mock_target = "pymongo.synchronous.helpers._RetryPolicy.backoff"
else:
mock_target = "pymongo.helpers._RetryPolicy.backoff"
# Create a client.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with only the `RetryableError` error label.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 91,
"errorLabels": ["RetryableError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(event_listeners=[listener])
client.test.test.insert_one({})
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Perform a findOne operation with coll. Expect the operation to fail.
with mock.patch(mock_target, return_value=0) as mock_backoff:
with self.assertRaises(PyMongoError):
client.test.test.find_one()
# Assert that backoff was applied only once for the initial overload error and not for the subsequent non-overload retryable errors.
self.assertEqual(mock_backoff.call_count, 1)
if __name__ == "__main__":
unittest.main()

View File

@ -21,6 +21,9 @@ import pprint
import sys
import threading
from test.utils import flaky, set_fail_point
from unittest import mock
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
@ -43,14 +46,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo import MongoClient
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
NotPrimaryError,
PyMongoError,
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import (
CommandFailedEvent,
CommandSucceededEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
@ -597,5 +603,291 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
self.assertEqual(sent_txn_id, final_txn_id, msg)
class TestErrorPropagationAfterEncounteringMultipleErrors(IntegrationTest):
# Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers.
@client_context.require_replica_set
# Run against server versions 6.0 and above.
@client_context.require_version_min(6, 0) # type: ignore[untyped-decorator]
def setUp(self) -> None:
super().setUp()
self.setup_client = MongoClient(**client_context.default_client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand"}
cmd.update(command_args)
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 10107.
assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload]
def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
# and a NoWritesPerformed label.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
},
}
def failed(event: CommandFailedEvent) -> None:
if listener.failed_events:
return
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91 # type:ignore[call-overload]
def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
# `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(PyMongoError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91
# Assert that the error does not contain the error label `NoWritesPerformed`.
assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"]
def test_overload_then_nonoverload_retries_increased_writes(self) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and `RetryableWriteError` error labels.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "RetryableWriteError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
with self.assertRaises(PyMongoError):
client.test.test.insert_one({"x": 1})
started_inserts = [e for e in listener.started_events if e.command_name == "insert"]
self.assertEqual(len(started_inserts), MAX_ADAPTIVE_RETRIES + 1)
def test_backoff_is_not_applied_for_non_overload_errors(self):
if _IS_SYNC:
mock_target = "pymongo.synchronous.helpers._RetryPolicy.backoff"
else:
mock_target = "pymongo.helpers._RetryPolicy.backoff"
# Create a client.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (ShutdownInProgress) and `RetryableError` and `SystemOverloadedError` labels.
overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with only the `RetryableError` error label.
non_overload_fail_point = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "RetryableWriteError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(non_overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(event_listeners=[listener])
self.configure_fail_point_sync(overload_fail_point)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Perform a findOne operation with coll. Expect the operation to fail.
with mock.patch(mock_target, return_value=0) as mock_backoff:
with self.assertRaises(PyMongoError):
client.test.test.insert_one({})
# Assert that backoff was applied only once for the initial overload error and not for the subsequent non-overload retryable errors.
self.assertEqual(mock_backoff.call_count, 1)
if __name__ == "__main__":
unittest.main()

View File

@ -15,7 +15,6 @@
"""Test the client_session module."""
from __future__ import annotations
import asyncio
import copy
import sys
import time
@ -24,8 +23,6 @@ from io import BytesIO
from test.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple
from pymongo.synchronous.mongo_client import MongoClient
sys.path[0:0] = [""]
from test import (
@ -45,7 +42,7 @@ from test.utils_shared import (
from bson import DBRef
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo import ASCENDING, MongoClient, _csot, monitoring
from pymongo import ASCENDING, MongoClient, monitoring
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
@ -938,6 +935,39 @@ class TestSession(IntegrationTest):
s2.end_session()
def test_getmore_preserves_lsid_after_session_support_lost(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
coll = client.pymongo_test.test
coll.drop()
coll.insert_many([{"x": i} for i in range(10)])
self.addCleanup(coll.drop)
with client.start_session() as s:
cursor = coll.find({}, batch_size=2, session=s)
next(cursor)
find_event = next(e for e in listener.started_events if e.command_name == "find")
lsid = find_event.command["lsid"]
# Simulate a node stepping down: mark idle connections as not supporting sessions.
for server in client._topology._servers.values():
for conn in server.pool.conns:
conn.supports_sessions = False
listener.reset()
cursor.to_list()
getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
for event in getmore_events:
self.assertIn(
"lsid", event.command, "getMore must include lsid when session is materialized"
)
self.assertEqual(
lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
)
class TestCausalConsistency(UnitTest):
listener: SessionTestListener

View File

@ -145,13 +145,11 @@ class TestSON(unittest.TestCase):
self.assertEqual(ele * 100, test_son[ele])
def test_contains_has(self):
"""has_key and __contains__"""
"""Test key membership via 'in' and __contains__."""
test_son = SON([(1, 100), (2, 200), (3, 300)])
self.assertIn(1, test_son)
self.assertIn(2, test_son, "in failed")
self.assertNotIn(22, test_son, "in succeeded when it shouldn't")
self.assertTrue(test_son.has_key(2), "has_key failed")
self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't")
def test_clears(self):
"""Test clear()"""

View File

@ -23,7 +23,7 @@ sys.path[0:0] = [""]
from test import client_knobs, unittest
from test.pymongo_mocks import DummyMonitor
from test.utils import MockPool, flaky
from test.utils import MockPool
from test.utils_shared import wait_until
from bson.objectid import ObjectId
@ -755,7 +755,6 @@ def wait_for_primary(topology):
class TestTopologyErrors(TopologyTest):
# Errors when calling hello.
@flaky(reason="PYTHON-5366")
def test_pool_reset(self):
# hello succeeds at first, then always raises socket error.
hello_count = [0]
@ -776,7 +775,11 @@ class TestTopologyErrors(TopologyTest):
# Pool is reset by hello failure.
t.request_check_all()
self.assertNotEqual(generation, server.pool.gen.get_overall())
# Wait for the monitor's hello failure to trigger Pool.reset() and bump the generation.
wait_until(
lambda: server.pool.gen.get_overall() != generation,
"pool reset after failed monitor check",
)
def test_hello_retry(self):
# hello succeeds at first, then raises socket error, then succeeds.

View File

@ -16,9 +16,13 @@
from __future__ import annotations
import asyncio
import random
import sys
import time
from io import BytesIO
from unittest.mock import patch
import pymongo
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import PoolState
@ -40,7 +44,9 @@ from pymongo.errors import (
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
@ -426,7 +432,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.configure_fail_point(client, command_args)
@client_context.require_transactions
def test_callback_raises_custom_error(self):
def test_1_callback_raises_custom_error(self):
class _MyException(Exception):
pass
@ -438,7 +444,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
s.with_transaction(raise_error)
@client_context.require_transactions
def test_callback_returns_value(self):
def test_2_callback_returns_value(self):
def callback(_):
return "Foo"
@ -466,7 +472,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual(s.with_transaction(callback), "Foo")
@client_context.require_transactions
def test_callback_not_retried_after_timeout(self):
def test_3_1_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -487,14 +493,16 @@ class TestTransactionsConvenientAPI(TransactionsBase):
listener.reset()
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands
@client_context.require_transactions
def test_callback_not_retried_after_commit_timeout(self):
def test_3_2_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -519,14 +527,16 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands
@client_context.require_transactions
def test_commit_not_retried_after_timeout(self):
def test_3_3_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -548,7 +558,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
@ -556,6 +566,40 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@client_context.require_transactions
def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
def callback(session):
coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
coll.insert_one({})
listener.reset()
with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
s.with_transaction(callback)
# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])
self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)
# Tested here because this supports Motor's convenient transactions API.
@client_context.require_transactions
@ -594,6 +638,63 @@ class TestTransactionsConvenientAPI(TransactionsBase):
s.with_transaction(callback)
self.assertFalse(s.in_transaction)
@client_context.require_test_commands
@client_context.require_transactions
def test_4_retry_backoff_is_enforced(self):
client = client_context.client
coll = client[self.db.name].test
end = start = no_backoff_time = 0
# Make random.random always return 0 (no backoff)
with patch.object(random, "random", return_value=0):
# set fail point to trigger transaction failure and trigger backoff
self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 13},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
def callback(session):
coll.insert_one({}, session=session)
start = time.monotonic()
with self.client.start_session() as s:
s.with_transaction(callback)
end = time.monotonic()
no_backoff_time = end - start
# Make random.random always return 1 (max backoff)
with patch.object(random, "random", return_value=1):
# set fail point to trigger transaction failure and trigger backoff
self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {
"times": 13
}, # sufficiently high enough such that the time effect of backoff is noticeable
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
start = time.monotonic()
with self.client.start_session() as s:
s.with_transaction(callback)
end = time.monotonic()
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
class TestOptionsInsideTransactionProse(TransactionsBase):
@client_context.require_transactions

View File

@ -0,0 +1,342 @@
{
"description": "backpressure-retryable-abort",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "abortTransaction retries if backpressure labels are added",
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"abortTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
},
{
"description": "abortTransaction is retried maxAttempts=2 times if backpressure labels are added",
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"abortTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
}
]
}

View File

@ -0,0 +1,359 @@
{
"description": "backpressure-retryable-commit",
"schemaVersion": "1.4",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"sharded",
"replicaset",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "commitTransaction retries if backpressure labels are added",
"runOnRequirements": [
{
"serverless": "forbid"
}
],
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"commitTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "commitTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": [
{
"_id": 1
}
]
}
]
},
{
"description": "commitTransaction is retried maxAttempts=2 times if backpressure labels are added",
"runOnRequirements": [
{
"serverless": "forbid"
}
],
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"commitTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "commitTransaction",
"expectError": {
"isError": true
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"commandName": "commitTransaction"
}
},
{
"commandStartedEvent": {
"commandName": "commitTransaction"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
}
]
}

View File

@ -0,0 +1,313 @@
{
"description": "backpressure-retryable-reads",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "reads are retried if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "find",
"arguments": {
"filter": {},
"session": "session0"
}
},
{
"object": "session0",
"name": "commitTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"find": "test",
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "find",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"find": "test",
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "find",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
}
]
}
]
},
{
"description": "reads are retried maxAttempts=2 times if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "find",
"arguments": {
"filter": {},
"session": "session0"
},
"expectError": {
"isError": true
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
}
]
}
]
}
]
}

View File

@ -0,0 +1,439 @@
{
"description": "backpressure-retryable-writes",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "writes are retried if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 2
}
}
},
{
"object": "session0",
"name": "commitTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 2
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 2
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
]
},
{
"description": "writes are retried maxAttempts=2 times if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"insert"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 2
}
},
"expectError": {
"isError": true
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
},
{
"description": "retry succeeds if backpressure labels are added to the first operation in a transaction",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 2
}
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"startTransaction": true
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"startTransaction": true
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"startTransaction": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
}
]
}

View File

@ -1451,11 +1451,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
self.assertListEqual(sorted_expected_documents, actual_documents)
def run_scenario(self, spec, uri=None):
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
self.kill_all_sessions()
# Handle flaky tests.
flaky_tests = [
("PYTHON-5170", ".*test_discovery_and_monitoring.*"),
@ -1491,6 +1486,15 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
if skip_reason is not None:
raise unittest.SkipTest(f"{skip_reason}")
# Kill all sessions after each test with transactions to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
for op in spec["operations"]:
name = op["name"]
if name == "startTransaction" or name == "withTransaction":
self.addCleanup(self.kill_all_sessions)
break
# process createEntities
self._uri = uri
self.entity_map = EntityMapUtil(self)

View File

@ -0,0 +1,66 @@
{
"tests": [
{
"description": "maxAdaptiveRetries is parsed correctly",
"uri": "mongodb://example.com/?maxAdaptiveRetries=3",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"maxAdaptiveRetries": 3
}
},
{
"description": "maxAdaptiveRetries=0 is parsed correctly",
"uri": "mongodb://example.com/?maxAdaptiveRetries=0",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"maxAdaptiveRetries": 0
}
},
{
"description": "maxAdaptiveRetries with invalid value causes a warning",
"uri": "mongodb://example.com/?maxAdaptiveRetries=-5",
"valid": true,
"warning": true,
"hosts": null,
"auth": null,
"options": null
},
{
"description": "enableOverloadRetargeting is parsed correctly",
"uri": "mongodb://example.com/?enableOverloadRetargeting=true",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"enableOverloadRetargeting": true
}
},
{
"description": "enableOverloadRetargeting=false is parsed correctly",
"uri": "mongodb://example.com/?enableOverloadRetargeting=false",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"enableOverloadRetargeting": false
}
},
{
"description": "enableOverloadRetargeting with invalid value causes a warning",
"uri": "mongodb://example.com/?enableOverloadRetargeting=invalid",
"valid": true,
"warning": true,
"hosts": null,
"auth": null,
"options": null
}
]
}

View File

@ -16,43 +16,13 @@
from __future__ import annotations
import asyncio
import functools
import os
import time
import unittest
from collections import abc
from inspect import iscoroutinefunction
from test import IntegrationTest, client_context, client_knobs
from test import client_context
from test.helpers import ConcurrentRunner
from test.utils_shared import (
CMAPListener,
CompareType,
EventListener,
OvertCommandListener,
ScenarioDict,
ServerAndTopologyEventListener,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
prepare_spec_arguments,
)
from typing import List
from test.utils_shared import ScenarioDict
from bson import ObjectId, decode, encode, json_util
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from gridfs.synchronous.grid_file import GridFSBucket
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from bson import json_util
from pymongo.lock import _cond_wait, _create_condition, _create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
from pymongo.synchronous import client_session
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.cursor import Cursor
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
@ -219,595 +189,3 @@ class SpecTestCreator:
self._create_tests()
else:
asyncio.run(self._create_tests())
class SpecRunner(IntegrationTest):
mongos_clients: List
knobs: client_knobs
listener: EventListener
def setUp(self) -> None:
super().setUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
def tearDown(self) -> None:
self.knobs.disable()
def set_fail_point(self, command_args):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
self.configure_fail_point(client, command_args)
def targeted_fail_point(self, session, fail_point):
"""Run the targetedFailPoint test operation.
Enable the fail point on the session's pinned mongos.
"""
clients = {c.address: c for c in self.mongos_clients}
client = clients[session._pinned_address]
self.configure_fail_point(client, fail_point)
self.addCleanup(self.set_fail_point, {"mode": "off"})
def assert_session_pinned(self, session):
"""Run the assertSessionPinned test operation.
Assert that the given session is pinned.
"""
self.assertIsNotNone(session._transaction.pinned_address)
def assert_session_unpinned(self, session):
"""Run the assertSessionUnpinned test operation.
Assert that the given session is not pinned.
"""
self.assertIsNone(session._pinned_address)
self.assertIsNone(session._transaction.pinned_address)
def assert_collection_exists(self, database, collection):
"""Run the assertCollectionExists test operation."""
db = self.client[database]
self.assertIn(collection, db.list_collection_names())
def assert_collection_not_exists(self, database, collection):
"""Run the assertCollectionNotExists test operation."""
db = self.client[database]
self.assertNotIn(collection, db.list_collection_names())
def assert_index_exists(self, database, collection, index):
"""Run the assertIndexExists test operation."""
coll = self.client[database][collection]
self.assertIn(index, [doc["name"] for doc in coll.list_indexes()])
def assert_index_not_exists(self, database, collection, index):
"""Run the assertIndexNotExists test operation."""
coll = self.client[database][collection]
self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()])
def wait(self, ms):
"""Run the "wait" test operation."""
time.sleep(ms / 1000.0)
def assertErrorLabelsContain(self, exc, expected_labels):
labels = [l for l in expected_labels if exc.has_error_label(l)]
self.assertEqual(labels, expected_labels)
def assertErrorLabelsOmit(self, exc, omit_labels):
for label in omit_labels:
self.assertFalse(
exc.has_error_label(label), msg=f"error labels should not contain {label}"
)
def kill_all_sessions(self):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
try:
client.admin.command("killAllSessions", [])
except (OperationFailure, AutoReconnect):
# "operation was interrupted" by killing the command's
# own session.
# On 8.0+ killAllSessions sometimes returns a network error.
pass
def check_command_result(self, expected_result, result):
# Only compare the keys in the expected result.
filtered_result = {}
for key in expected_result:
try:
filtered_result[key] = result[key]
except KeyError:
pass
self.assertEqual(filtered_result, expected_result)
# TODO: factor the following function with test_crud.py.
def check_result(self, expected_result, result):
if isinstance(result, _WriteResult):
for res in expected_result:
prop = camel_to_snake(res)
# SPEC-869: Only BulkWriteResult has upserted_count.
if prop == "upserted_count" and not isinstance(result, BulkWriteResult):
if result.upserted_id is not None:
upserted_count = 1
else:
upserted_count = 0
self.assertEqual(upserted_count, expected_result[res], prop)
elif prop == "inserted_ids":
# BulkWriteResult does not have inserted_ids.
if isinstance(result, BulkWriteResult):
self.assertEqual(len(expected_result[res]), result.inserted_count)
else:
# InsertManyResult may be compared to [id1] from the
# crud spec or {"0": id1} from the retryable write spec.
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, prop)
elif prop == "upserted_ids":
# Convert indexes from strings to integers.
ids = expected_result[res]
expected_ids = {}
for str_index in ids:
expected_ids[int(str_index)] = ids[str_index]
self.assertEqual(expected_ids, result.upserted_ids, prop)
else:
self.assertEqual(getattr(result, prop), expected_result[res], prop)
return True
else:
def _helper(expected_result, result):
if isinstance(expected_result, abc.Mapping):
for i in expected_result.keys():
self.assertEqual(expected_result[i], result[i])
elif isinstance(expected_result, list):
for i, k in zip(expected_result, result):
_helper(i, k)
else:
self.assertEqual(expected_result, result)
_helper(expected_result, result)
return None
def get_object_name(self, op):
"""Allow subclasses to override handling of 'object'
Transaction spec says 'object' is required.
"""
return op["object"]
@staticmethod
def parse_options(opts):
return parse_spec_options(opts)
def run_operation(self, sessions, collection, operation):
original_collection = collection
name = camel_to_snake(operation["name"])
if name == "run_command":
name = "command"
elif name == "download_by_name":
name = "open_download_stream_by_name"
elif name == "download":
name = "open_download_stream"
elif name == "map_reduce":
self.skipTest("PyMongo does not support mapReduce")
elif name == "count":
self.skipTest("PyMongo does not support count")
database = collection.database
collection = database.get_collection(collection.name)
if "collectionOptions" in operation:
collection = collection.with_options(
**self.parse_options(operation["collectionOptions"])
)
object_name = self.get_object_name(operation)
if object_name == "gridfsbucket":
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = GridFSBucket(database, bucket_name=collection.name)
else:
objects = {
"client": database.client,
"database": database,
"collection": collection,
"testRunner": self,
}
objects.update(sessions)
obj = objects[object_name]
# Combine arguments with options and handle special cases.
arguments = operation.get("arguments", {})
arguments.update(arguments.pop("options", {}))
self.parse_options(arguments)
cmd = getattr(obj, name)
with_txn_callback = functools.partial(
self.run_operations, sessions, original_collection, in_with_transaction=True
)
prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)
if name == "run_on_thread":
args = {"sessions": sessions, "collection": collection}
args.update(arguments)
arguments = args
if not _IS_SYNC and iscoroutinefunction(cmd):
result = cmd(**dict(arguments))
else:
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addCleanup(result.close)
if name == "aggregate":
if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
# Read from the primary to ensure causal consistency.
out = collection.database.get_collection(
arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
)
return out.find()
if "download" in name:
result = Binary(result.read())
if isinstance(result, Cursor) or isinstance(result, CommandCursor):
return result.to_list()
return result
def allowable_errors(self, op):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
def _run_op(self, sessions, collection, op, in_with_transaction):
expected_result = op.get("result")
if expect_error(op):
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
self.run_operation(sessions, collection, op.copy())
exc = context.exception
if expect_error_message(expected_result):
if isinstance(exc, BulkWriteError):
errmsg = str(exc.details).lower()
else:
errmsg = str(exc).lower()
self.assertIn(expected_result["errorContains"].lower(), errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
if expect_timeout_error(expected_result):
self.assertIsInstance(exc, PyMongoError)
if not exc.timeout:
# Re-raise the exception for better diagnostics.
raise exc
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = self.run_operation(sessions, collection, op.copy())
if "result" in op:
if op["name"] == "runCommand":
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
def run_operations(self, sessions, collection, ops, in_with_transaction=False):
for op in ops:
self._run_op(sessions, collection, op, in_with_transaction)
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
events = listener.started_events
if not len(test["expectations"]):
return
# Give a nicer message when there are missing or extra events
cmds = decode_raw([event.command for event in events])
self.assertEqual(len(events), len(test["expectations"]), cmds)
for i, expectation in enumerate(test["expectations"]):
event_type = next(iter(expectation))
event = events[i]
# The tests substitute 42 for any number other than 0.
if event.command_name == "getMore" and event.command["getMore"]:
event.command["getMore"] = Int64(42)
elif event.command_name == "killCursors":
event.command["cursors"] = [Int64(42)]
elif event.command_name == "update":
# TODO: remove this once PYTHON-1744 is done.
# Add upsert and multi fields back into expectations.
updates = expectation[event_type]["command"]["updates"]
for update in updates:
update.setdefault("upsert", False)
update.setdefault("multi", False)
# Replace afterClusterTime: 42 with actual afterClusterTime.
expected_cmd = expectation[event_type]["command"]
expected_read_concern = expected_cmd.get("readConcern")
if expected_read_concern is not None:
time = expected_read_concern.get("afterClusterTime")
if time == 42:
actual_time = event.command.get("readConcern", {}).get("afterClusterTime")
if actual_time is not None:
expected_read_concern["afterClusterTime"] = actual_time
recovery_token = expected_cmd.get("recoveryToken")
if recovery_token == 42:
expected_cmd["recoveryToken"] = CompareType(dict)
# Replace lsid with a name like "session0" to match test.
if "lsid" in event.command:
for name, lsid in session_ids.items():
if event.command["lsid"] == lsid:
event.command["lsid"] = name
break
for attr, expected in expectation[event_type].items():
actual = getattr(event, attr)
expected = wrap_types(expected)
if isinstance(expected, dict):
for key, val in expected.items():
if val is None:
if key in actual:
self.fail(f"Unexpected key [{key}] in {actual!r}")
elif key not in actual:
self.fail(f"Expected key [{key}] in {actual!r}")
else:
self.assertEqual(
val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
)
else:
self.assertEqual(actual, expected)
def maybe_skip_scenario(self, test):
if test.get("skipReason"):
self.skipTest(test.get("skipReason"))
def get_scenario_db_name(self, scenario_def):
"""Allow subclasses to override a test's database name."""
return scenario_def["database_name"]
def get_scenario_coll_name(self, scenario_def):
"""Allow subclasses to override a test's collection name."""
return scenario_def["collection_name"]
def get_outcome_coll_name(self, outcome, collection):
"""Allow subclasses to override outcome collection."""
return collection.name
def run_test_ops(self, sessions, collection, test):
"""Added to allow retryable writes spec to override a test's
operation.
"""
self.run_operations(sessions, collection, test["operations"])
def parse_client_options(self, opts):
"""Allow encryption spec to override a clientOptions parsing."""
return opts
def setup_scenario(self, scenario_def):
"""Allow specs to override a test's setup."""
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
documents = scenario_def["data"]
# Setup the collection with as few majority writes as possible.
db = client_context.client.get_database(db_name)
coll_exists = bool(db.list_collection_names(filter={"name": coll_name}))
if coll_exists:
db[coll_name].delete_many({})
# Only use majority wc only on the final write.
wc = WriteConcern(w="majority")
if documents:
db.get_collection(coll_name, write_concern=wc).insert_many(documents)
elif not coll_exists:
# Ensure collection exists.
db.create_collection(coll_name, write_concern=wc)
def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
self.kill_all_sessions()
self.addCleanup(self.kill_all_sessions)
self.setup_scenario(scenario_def)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
c[database_name][collection_name].distinct("x")
# Configure the fail point before creating the client.
if "failPoint" in test:
fp = test["failPoint"]
self.set_fail_point(fp)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
listener = OvertCommandListener()
pool_listener = CMAPListener()
server_listener = ServerAndTopologyEventListener()
# Create a new client, to avoid interference from pooled sessions.
client_options = self.parse_client_options(test["clientOptions"])
use_multi_mongos = test["useMultipleMongoses"]
host = None
if use_multi_mongos:
if client_context.load_balancer:
host = client_context.MULTI_MONGOS_LB_URI
elif client_context.is_mongos:
host = client_context.mongos_seeds()
client = self.rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Create session0 and session1.
sessions = {}
session_ids = {}
for i in range(2):
# Don't attempt to create sessions if they are not supported by
# the running server version.
if not client_context.sessions_enabled:
break
session_name = "session%d" % i
opts = camel_to_snake_args(test["sessionOptions"][session_name])
if "default_transaction_options" in opts:
txn_opts = self.parse_options(opts["default_transaction_options"])
txn_opts = client_session.TransactionOptions(**txn_opts)
opts["default_transaction_options"] = txn_opts
s = client.start_session(**dict(opts))
sessions[session_name] = s
# Store lsid so we can access it after end_session, in check_events.
session_ids[session_name] = s.session_id
self.addCleanup(end_sessions, sessions)
collection = client[database_name][collection_name]
self.run_test_ops(sessions, collection, test)
end_sessions(sessions)
self.check_events(test, listener, session_ids)
# Disable fail points.
if "failPoint" in test:
fp = test["failPoint"]
self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"})
# Assert final state is expected.
outcome = test["outcome"]
expected_c = outcome.get("collection")
if expected_c is not None:
outcome_coll_name = self.get_outcome_coll_name(outcome, collection)
# Read from the primary with local read concern to ensure causal
# consistency.
outcome_coll = client_context.client[collection.database.name].get_collection(
outcome_coll_name,
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.
self.assertEqual(wrap_types(expected_c["data"]), actual_data)
def expect_any_error(op):
if isinstance(op, dict):
return op.get("error")
return False
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return isinstance(expected_result["errorContains"], str)
return False
def expect_error_code(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorCodeName"]
return False
def expect_error_labels_contain(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsContain"]
return False
def expect_error_labels_omit(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsOmit"]
return False
def expect_timeout_error(expected_result):
if isinstance(expected_result, dict):
return expected_result["isTimeoutError"]
return False
def expect_error(op):
expected_result = op.get("result")
return (
expect_any_error(op)
or expect_error_message(expected_result)
or expect_error_code(expected_result)
or expect_error_labels_contain(expected_result)
or expect_error_labels_omit(expected_result)
or expect_timeout_error(expected_result)
)
def end_sessions(sessions):
for s in sessions.values():
# Aborts the transaction if it's open.
s.end_session()
def decode_raw(val):
"""Decode RawBSONDocuments in the given container."""
if isinstance(val, (list, abc.Mapping)):
return decode(encode({"v": val}))["v"]
return val
TYPES = {
"binData": Binary,
"long": Int64,
"int": int,
"string": str,
"objectId": ObjectId,
"object": dict,
"array": list,
}
def wrap_types(val):
"""Support $$type assertion in command results."""
if isinstance(val, list):
return [wrap_types(v) for v in val]
if isinstance(val, abc.Mapping):
typ = val.get("$$type")
if typ:
if isinstance(typ, str):
types = TYPES[typ]
else:
types = tuple(TYPES[t] for t in typ)
return CompareType(types)
d = {}
for key in val:
d[key] = wrap_types(val[key])
return d
return val

View File

@ -213,6 +213,7 @@ converted_tests = [
"test_bulk.py",
"test_change_stream.py",
"test_client.py",
"test_client_backpressure.py",
"test_client_bulk_write.py",
"test_client_context.py",
"test_client_metadata.py",
@ -350,7 +351,7 @@ def translate_async_sleeps(lines: list[str]) -> list[str]:
sleeps = [line for line in lines if "asyncio.sleep" in line]
for line in sleeps:
res = re.search(r"asyncio.sleep\(([^()]*)\)", line)
res = re.search(r"asyncio\.sleep\(\s*(.*?)\)", line)
if res:
old = res[0]
index = lines.index(line)