Merge branch 'master' into fix/int32-overflow-opmsg

This commit is contained in:
Jib 2026-04-03 13:21:40 -04:00 committed by GitHub
commit 5c0ee69fe5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
47 changed files with 4096 additions and 1405 deletions

View File

@ -7,6 +7,8 @@ import subprocess
from argparse import Namespace
from subprocess import CalledProcessError
JIRA_FILTER = "https://jira.mongodb.org/issues/?jql=labels%20%3D%20automated-sync%20AND%20status%20!%3D%20Closed"
def resync_specs(directory: pathlib.Path, errored: dict[str, str]) -> None:
"""Actually sync the specs"""
@ -117,6 +119,7 @@ def write_summary(errored: dict[str, str], new: list[str], filename: str | None)
pr_body += "\n -".join(new)
pr_body += "\n"
if pr_body != "":
pr_body = f"Jira tickets: {JIRA_FILTER}\n\n" + pr_body
if filename is None:
print(f"\n{pr_body}")
else:

View File

@ -153,6 +153,10 @@ def handle_test_env() -> None:
# Start compiling the args we'll pass to uv.
UV_ARGS = ["--extra test --no-group dev"]
# If USE_ACTIVE_VENV is set, add --active to UV_ARGS so run-tests.sh uses the active venv.
if is_set("USE_ACTIVE_VENV"):
UV_ARGS.append("--active")
test_title = test_name
if sub_test_name:
test_title += f" {sub_test_name}"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,460 @@
diff --git a/test/client-side-encryption/spec/unified/accessToken-azure.json b/test/client-side-encryption/spec/unified/accessToken-azure.json
new file mode 100644
index 00000000..510d8795
--- /dev/null
+++ b/test/client-side-encryption/spec/unified/accessToken-azure.json
@@ -0,0 +1,186 @@
+{
+ "description": "accessToken-azure",
+ "schemaVersion": "1.28",
+ "runOnRequirements": [
+ {
+ "minServerVersion": "4.1.10",
+ "csfle": {
+ "minLibmongocryptVersion": "1.6.0"
+ }
+ }
+ ],
+ "createEntities": [
+ {
+ "client": {
+ "id": "client",
+ "autoEncryptOpts": {
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "azure": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ },
+ {
+ "database": {
+ "id": "db",
+ "client": "client",
+ "databaseName": "db"
+ }
+ },
+ {
+ "collection": {
+ "id": "coll",
+ "database": "db",
+ "collectionName": "coll"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "azure": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ }
+ ],
+ "initialData": [
+ {
+ "databaseName": "db",
+ "collectionName": "coll",
+ "documents": [],
+ "createOptions": {
+ "validator": {
+ "$jsonSchema": {
+ "properties": {
+ "secret": {
+ "encrypt": {
+ "keyId": [
+ {
+ "$binary": {
+ "base64": "AZURE+AAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ }
+ ],
+ "bsonType": "string",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ }
+ },
+ "bsonType": "object"
+ }
+ }
+ }
+ },
+ {
+ "databaseName": "keyvault",
+ "collectionName": "datakeys",
+ "documents": [
+ {
+ "_id": {
+ "$binary": {
+ "base64": "AZURE+AAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ },
+ "keyAltNames": [
+ "my-key"
+ ],
+ "keyMaterial": {
+ "$binary": {
+ "base64": "n+HWZ0ZSVOYA3cvQgP7inN4JSXfOH85IngmeQxRpQHjCCcqT3IFqEWNlrsVHiz3AELimHhX4HKqOLWMUeSIT6emUDDoQX9BAv8DR1+E1w4nGs/NyEneac78EYFkK3JysrFDOgl2ypCCTKAypkn9CkAx1if4cfgQE93LW4kczcyHdGiH36CIxrCDGv1UzAvERN5Qa47DVwsM6a+hWsF2AAAJVnF0wYLLJU07TuRHdMrrphPWXZsFgyV+lRqJ7DDpReKNO8nMPLV/mHqHBHGPGQiRdb9NoJo8CvokGz4+KE8oLwzKf6V24dtwZmRkrsDV4iOhvROAzz+Euo1ypSkL3mw==",
+ "subType": "00"
+ }
+ },
+ "creationDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "updateDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "status": {
+ "$numberInt": "0"
+ },
+ "masterKey": {
+ "provider": "azure",
+ "keyVaultEndpoint": "key-vault-csfle.vault.azure.net",
+ "keyName": "key-name-csfle"
+ }
+ }
+ ]
+ }
+ ],
+ "tests": [
+ {
+ "description": "Auto encrypt using access token Azure credentials",
+ "operations": [
+ {
+ "name": "insertOne",
+ "arguments": {
+ "document": {
+ "_id": 1,
+ "secret": "string0"
+ }
+ },
+ "object": "coll"
+ }
+ ],
+ "outcome": [
+ {
+ "documents": [
+ {
+ "_id": 1,
+ "secret": {
+ "$binary": {
+ "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==",
+ "subType": "06"
+ }
+ }
+ }
+ ],
+ "collectionName": "coll",
+ "databaseName": "db"
+ }
+ ]
+ },
+ {
+ "description": "Explicit encrypt using access token Azure credentials",
+ "operations": [
+ {
+ "name": "encrypt",
+ "object": "clientEncryption",
+ "arguments": {
+ "value": "string0",
+ "opts": {
+ "keyAltName": "my-key",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ },
+ "expectResult": {
+ "$binary": {
+ "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==",
+ "subType": "06"
+ }
+ }
+ }
+ ]
+ }
+ ]
+}
diff --git a/test/client-side-encryption/spec/unified/accessToken-gcp.json b/test/client-side-encryption/spec/unified/accessToken-gcp.json
new file mode 100644
index 00000000..f5cf8914
--- /dev/null
+++ b/test/client-side-encryption/spec/unified/accessToken-gcp.json
@@ -0,0 +1,188 @@
+{
+ "description": "accessToken-gcp",
+ "schemaVersion": "1.28",
+ "runOnRequirements": [
+ {
+ "minServerVersion": "4.1.10",
+ "csfle": {
+ "minLibmongocryptVersion": "1.6.0"
+ }
+ }
+ ],
+ "createEntities": [
+ {
+ "client": {
+ "id": "client",
+ "autoEncryptOpts": {
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "gcp": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ },
+ {
+ "database": {
+ "id": "db",
+ "client": "client",
+ "databaseName": "db"
+ }
+ },
+ {
+ "collection": {
+ "id": "coll",
+ "database": "db",
+ "collectionName": "coll"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "gcp": {
+ "accessToken": {
+ "$$placeholder": 1
+ }
+ }
+ }
+ }
+ }
+ }
+ ],
+ "initialData": [
+ {
+ "databaseName": "db",
+ "collectionName": "coll",
+ "documents": [],
+ "createOptions": {
+ "validator": {
+ "$jsonSchema": {
+ "properties": {
+ "secret": {
+ "encrypt": {
+ "keyId": [
+ {
+ "$binary": {
+ "base64": "GCP+AAAAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ }
+ ],
+ "bsonType": "string",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ }
+ },
+ "bsonType": "object"
+ }
+ }
+ }
+ },
+ {
+ "databaseName": "keyvault",
+ "collectionName": "datakeys",
+ "documents": [
+ {
+ "_id": {
+ "$binary": {
+ "base64": "GCP+AAAAAAAAAAAAAAAAAA==",
+ "subType": "04"
+ }
+ },
+ "keyAltNames": [
+ "my-key"
+ ],
+ "keyMaterial": {
+ "$binary": {
+ "base64": "CiQAIgLj0WyktnB4dfYHo5SLZ41K4ASQrjJUaSzl5vvVH0G12G0SiQEAjlV8XPlbnHDEDFbdTO4QIe8ER2/172U1ouLazG0ysDtFFIlSvWX5ZnZUrRMmp/R2aJkzLXEt/zf8Mn4Lfm+itnjgo5R9K4pmPNvvPKNZX5C16lrPT+aA+rd+zXFSmlMg3i5jnxvTdLHhg3G7Q/Uv1ZIJskKt95bzLoe0tUVzRWMYXLIEcohnQg==",
+ "subType": "00"
+ }
+ },
+ "creationDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "updateDate": {
+ "$date": {
+ "$numberLong": "1552949630483"
+ }
+ },
+ "status": {
+ "$numberInt": "0"
+ },
+ "masterKey": {
+ "provider": "gcp",
+ "projectId": "devprod-drivers",
+ "location": "global",
+ "keyRing": "key-ring-csfle",
+ "keyName": "key-name-csfle"
+ }
+ }
+ ]
+ }
+ ],
+ "tests": [
+ {
+ "description": "Auto encrypt using access token GCP credentials",
+ "operations": [
+ {
+ "name": "insertOne",
+ "arguments": {
+ "document": {
+ "_id": 1,
+ "secret": "string0"
+ }
+ },
+ "object": "coll"
+ }
+ ],
+ "outcome": [
+ {
+ "documents": [
+ {
+ "_id": 1,
+ "secret": {
+ "$binary": {
+ "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==",
+ "subType": "06"
+ }
+ }
+ }
+ ],
+ "collectionName": "coll",
+ "databaseName": "db"
+ }
+ ]
+ },
+ {
+ "description": "Explicit encrypt using access token GCP credentials",
+ "operations": [
+ {
+ "name": "encrypt",
+ "object": "clientEncryption",
+ "arguments": {
+ "value": "string0",
+ "opts": {
+ "keyAltName": "my-key",
+ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+ }
+ },
+ "expectResult": {
+ "$binary": {
+ "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==",
+ "subType": "06"
+ }
+ }
+ }
+ ]
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json
new file mode 100644
index 00000000..8fe5c150
--- /dev/null
+++ b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json
@@ -0,0 +1,31 @@
+{
+ "description": "clientEncryptionOpts-kmsProviders-azure-accessToken-type",
+ "schemaVersion": "1.28",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption0",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client0",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "azure": {
+ "accessToken": 0
+ }
+ }
+ }
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "",
+ "operations": []
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json
new file mode 100644
index 00000000..2284e26c
--- /dev/null
+++ b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json
@@ -0,0 +1,31 @@
+{
+ "description": "clientEncryptionOpts-kmsProviders-gcp-accessToken-type",
+ "schemaVersion": "1.28",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ },
+ {
+ "clientEncryption": {
+ "id": "clientEncryption0",
+ "clientEncryptionOpts": {
+ "keyVaultClient": "client0",
+ "keyVaultNamespace": "keyvault.datakeys",
+ "kmsProviders": {
+ "gcp": {
+ "accessToken": 0
+ }
+ }
+ }
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "",
+ "operations": []
+ }
+ ]
+}

44
.github/copilot-instructions.md vendored Normal file
View File

@ -0,0 +1,44 @@
When reviewing code, focus on:
## Security Critical Issues
- Check for hardcoded secrets, API keys, or credentials.
- Check for instances of potential method call injection, dynamic code execution, symbol injection or other code injection vulnerabilities.
## Performance Red Flags
- Spot inefficient loops and algorithmic issues.
- Check for memory leaks and resource cleanup.
## Code Quality Essentials
- Methods should be focused and appropriately sized. If a method is doing too much, suggest refactorings to split it up.
- Use clear, descriptive naming conventions.
- Avoid encapsulation violations and ensure proper separation of concerns.
- All public classes, modules, and methods should have clear documentation in Sphinx format.
## PyMongo-specific Concerns
- Do not review files within `pymongo/synchronous` or files in `test/` that also have a file of the same name in `test/asynchronous` unless the reviewed changes include a `_IS_SYNC` statement. PyMongo generates these files from `pymongo/asynchronous` and `test/asynchronous` using `tools/synchro.py`.
- All asynchronous functions must not call any blocking I/O.
## Review Style
- Be specific and actionable in feedback.
- Explain the "why" behind recommendations.
- Acknowledge good patterns when you see them.
- Ask clarifying questions when code intent is unclear.
Always prioritize security vulnerabilities and performance issues that could impact users.
Always suggest changes to improve readability and testability. For example, this suggestion seeks to make the code more readable, reusable, and testable:
```python
# Instead of:
if user.email and "@" in user.email and len(user.email) > 5:
submit_button.enabled = True
else:
submit_button.enabled = False
# Consider:
def valid_email(email):
return email and "@" in email and len(email) > 5
submit_button.enabled = valid_email(user.email)
```

View File

@ -6,8 +6,8 @@ If you are an external contributor and there is no JIRA ticket associated with y
for the PR title. A MongoDB employee will create a JIRA ticket and edit the name and links as appropriate.
Note on AI Contributions:
We do not accept pull requests that are primarily or substantially generated by AI tools (ChatGPT, Copilot, etc.).
All contributions must be written and understood by human contributors.
We only accept pull requests that are authored and submitted by human contributors who fully understand the changes they are proposing.
All contributions must be written and understood by human contributors. Please read about our policy in our contributing guide.
-->
[JIRA TICKET]

View File

@ -61,7 +61,7 @@ jobs:
- name: Set up QEMU
if: runner.os == 'Linux'
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
with:
# setup-qemu-action by default uses `tonistiigi/binfmt:latest` image,
# which is out of date. This causes seg faults during build.
@ -92,7 +92,7 @@ jobs:
# Free-threading builds:
ls wheelhouse/*cp314t*.whl
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: wheel-${{ matrix.buildplat[1] }}
path: ./wheelhouse/*.whl
@ -125,7 +125,7 @@ jobs:
cd ..
python -c "from pymongo import has_c; assert has_c()"
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: "sdist"
path: ./dist/*.tar.gz
@ -136,13 +136,13 @@ jobs:
name: Download Wheels
steps:
- name: Download all workflow run artifacts
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
- name: Flatten directory
working-directory: .
run: |
find . -mindepth 2 -type f -exec mv {} . \;
find . -type d -empty -delete
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: all-dist-${{ github.run_id }}
path: "./*"

View File

@ -75,7 +75,7 @@ jobs:
id-token: write
steps:
- name: Download all the dists
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
name: all-dist-${{ github.run_id }}
path: dist/

View File

@ -67,7 +67,7 @@ jobs:
run: rm -rf .venv .venv-sbom sbom-requirements.txt
- name: Upload SBOM artifact
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: sbom
path: sbom.json

View File

@ -26,7 +26,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -68,7 +68,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -90,7 +90,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -118,7 +118,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -143,7 +143,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -162,7 +162,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -184,7 +184,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "${{matrix.python}}"
@ -205,7 +205,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
enable-cache: true
python-version: "3.10"
@ -245,7 +245,7 @@ jobs:
run: |
pip install build
python -m build --sdist
- uses: actions/upload-artifact@v6
- uses: actions/upload-artifact@v7
with:
name: "sdist"
path: dist/*.tar.gz
@ -257,7 +257,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Download sdist
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
path: sdist/
- name: Unpack SDist
@ -295,7 +295,7 @@ jobs:
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7
with:
python-version: "3.9"
- id: setup-mongodb

View File

@ -18,4 +18,4 @@ jobs:
with:
persist-credentials: false
- name: Run zizmor 🌈
uses: zizmorcore/zizmor-action@0dce2577a4760a2749d8cfb7a84b7d5585ebcb7d # v0.5.0
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2

1
.gitignore vendored
View File

@ -43,3 +43,4 @@ test/lambda/*.json
xunit-results/
coverage.xml
server.log
.coverage

View File

@ -85,49 +85,53 @@ likelihood for getting review sooner shoots up.
- `versionadded:: 3.11`
- `versionchanged:: 3.5`
**Pull Request Template Breakdown**
### AI-Generated Contributions Policy
- **Github PR Title**
#### Our Stance
- The PR Title format should always be
`[JIRA-ID] : Jira Title or Blurb Summary`.
We only accept pull requests that are authored and submitted by human contributors who fully understand the changes they are proposing. Pull requests that are not clearly owned and understood by a human contributor may be closed. **All contributions must be submitted, reviewed, and understood by human contributors.**
- **JIRA LINK**
##### Why This Policy Exists
- Convenient link to the associated JIRA ticket.
At MongoDB, we understand the power and prevalence of AI tools in software development. With that being said, many MongoDB libraries are foundational tools used in production systems worldwide. The nature of these libraries requires:
- **Summary**
- **Deep domain expertise**: MongoDB's wire protocol, BSON specification, connection pooling, authentication mechanisms, and concurrency patterns require an understanding that AI alone cannot substantiate.
- Small blurb on why this is needed. The JIRA task should have
the more in-depth description, but this should still, at a
high level, give anyone looking an understanding of why the
PR has been checked in.
- **Long-term maintainability**: Contributors need to be able to explain *why* code is written a certain way, explain design decisions, and be available to iterate on their contributions.
- **Changes in this PR**
- **Security responsibility**: Authentication, credential handling, and TLS implementation cannot be left to probabilistic code generation.
- The explicit code changes that this PR is introducing. This
should be more specific than just the task name. (Unless the
task name is very clear).
##### What This Means for Contributors
- **Test Plan**
**Required:**
- Everything needs a test description. Describe what you did
to validate your changes actually worked; if you did
nothing, then document you did not test it. Aim to make
these steps reproducible by other engineers, specifically
with your primary reviewer in mind.
- Full understanding of every line of code you submit
- Ability to explain and defend your implementation choices
- Willingness to iterate and maintain your contributions
- **Screenshots**
**Encouraged:**
- Any images that provide more context to the PR. Usually,
these just coincide with the test plan.
- Using AI assistants as learning tools to understand concepts
- IDE autocomplete features that suggest standard patterns
- AI help for brainstorming approaches (but write the code yourself)
- Writing code using AI tools, reviewing each line and revising code as necessary.
- **Callouts or follow-up items**
**Not allowed:**
- This is a good place for identifying "to-dos" that you've
placed in the code (Must have an accompanying JIRA Ticket).
- Potential bugs that you are unsure how to test in the code.
- Opinions you want to receive about your code.
- Submitting PRs generated solely by AI tools
- Copy-pasting AI-generated code without full understanding
##### Disclosure
If you used AI assistance in any way during your contribution, please disclose what the AI assistant was used for in your PR description. We would love to know what tools developers have found useful in iterating in their day to day.
##### Questions?
If you're unsure whether your contribution complies with this policy, please ask for guidance within the scope of the PR and clarify any uncertainty. We're happy to guide contributors toward successful contributions.
---
*This policy helps us maintain the reliability, security, and trustworthiness that production applications depend on. Thank you for understanding and for contributing thoughtfully to PyMongo.*
## Running Linters
@ -205,6 +209,7 @@ the pages will re-render and the browser will automatically refresh.
and the `<class_name>` to test a full module. For example:
`just test test/test_change_stream.py::TestUnifiedChangeStreamsErrors::test_change_stream_errors_on_ElectionInProgress`.
- Use the `-k` argument to select tests by pattern.
- Run `just test-coverage` to run tests with coverage and display a report. After running tests with coverage, use `just coverage-html` to generate an HTML report in `htmlcov/index.html`.
## Running tests that require secrets, services, or other configuration

View File

@ -4,6 +4,7 @@
[![Python Versions](https://img.shields.io/pypi/pyversions/pymongo)](https://pypi.org/project/pymongo)
[![Monthly Downloads](https://static.pepy.tech/badge/pymongo/month)](https://pepy.tech/project/pymongo)
[![API Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/api?badge=stable)
[![codecov](https://codecov.io/gh/mongodb/mongo-python-driver/graph/badge.svg?branch=master)](https://codecov.io/gh/mongodb/mongo-python-driver)
## About

View File

@ -356,7 +356,8 @@ static PyObject* datetime_ms_from_millis(PyObject* self, long long millis){
if (!(ll_millis = PyLong_FromLongLong(millis))){
return NULL;
}
dt = PyObject_CallFunctionObjArgs(state->DatetimeMS, ll_millis, NULL);
PyObject* args[1] = {ll_millis};
dt = PyObject_Vectorcall(state->DatetimeMS, args, 1, NULL);
Py_DECREF(ll_millis);
return dt;
}
@ -401,7 +402,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o
int64_t min_millis_offset = 0;
int64_t max_millis_offset = 0;
if (options->tz_aware && options->tzinfo && options->tzinfo != Py_None) {
PyObject* utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->min_datetime, NULL);
PyObject* utcoffset_args[2] = {options->tzinfo, state->min_datetime};
PyObject* utcoffset = PyObject_VectorcallMethod(
state->_utcoffset_str, utcoffset_args, 2, NULL);
if (utcoffset == NULL) {
return 0;
}
@ -420,7 +423,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o
(PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000);
}
Py_DECREF(utcoffset);
utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->max_datetime, NULL);
utcoffset_args[1] = state->max_datetime;
utcoffset = PyObject_VectorcallMethod(
state->_utcoffset_str, utcoffset_args, 2, NULL);
if (utcoffset == NULL) {
return 0;
}
@ -481,7 +486,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o
/* convert to local time */
if (options->tzinfo != Py_None) {
PyObject* temp = PyObject_CallMethodObjArgs(value, state->_astimezone_str, options->tzinfo, NULL);
PyObject* astimezone_args[2] = {value, options->tzinfo};
PyObject* temp = PyObject_VectorcallMethod(
state->_astimezone_str, astimezone_args, 2, NULL);
Py_DECREF(value);
value = temp;
}
@ -688,7 +695,8 @@ static int _load_python_objects(PyObject* module) {
return 1;
}
compiled = PyObject_CallFunction(re_compile, "O", empty_string);
PyObject* compile_args[1] = {empty_string};
compiled = PyObject_Vectorcall(re_compile, compile_args, 1, NULL);
Py_DECREF(re_compile);
if (compiled == NULL) {
state->REType = NULL;
@ -711,13 +719,19 @@ static long _type_marker(PyObject* object, PyObject* _type_marker_str) {
PyObject* type_marker = NULL;
long type = 0;
if (PyObject_HasAttr(object, _type_marker_str)) {
type_marker = PyObject_GetAttr(object, _type_marker_str);
if (type_marker == NULL) {
#if PY_VERSION_HEX >= 0x030D0000
// 3.13
if (PyObject_GetOptionalAttr(object, _type_marker_str, &type_marker) == -1) {
return -1;
}
}
# else
if (PyObject_HasAttr(object, _type_marker_str)) {
type_marker = PyObject_GetAttr(object, _type_marker_str);
if (type_marker == NULL) {
return -1;
}
}
#endif
/*
* Python objects with broken __getattr__ implementations could return
* arbitrary types for a call to PyObject_GetAttrString. For example
@ -814,6 +828,7 @@ int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t
}
options->is_raw_bson = (101 == type_marker);
options->is_dict_class = (options->document_class == (PyObject*)&PyDict_Type);
options->options_obj = options_obj;
Py_INCREF(options->options_obj);
@ -1013,10 +1028,20 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
}
/*
* Use _type_marker attribute instead of PyObject_IsInstance for better perf.
*
* Skip _type_marker lookup for common built-in types
* that we know don't have a _type_marker attribute. This avoids the overhead
* of PyObject_HasAttr/PyObject_GetAttr calls for the most common cases.
*/
type = _type_marker(value, state->_type_marker_str);
if (type < 0) {
return 0;
if (PyUnicode_CheckExact(value) || PyLong_CheckExact(value) || PyFloat_CheckExact(value) ||
PyBool_Check(value) || PyDict_CheckExact(value) || PyList_CheckExact(value) ||
PyTuple_CheckExact(value) || PyBytes_CheckExact(value) || value == Py_None) {
type = 0;
} else {
type = _type_marker(value, state->_type_marker_str);
if (type < 0) {
return 0;
}
}
switch (type) {
@ -1227,7 +1252,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
case 100:
{
/* DBRef */
PyObject* as_doc = PyObject_CallMethodObjArgs(value, state->_as_doc_str, NULL);
PyObject* as_doc_args[1] = {value};
PyObject* as_doc = PyObject_VectorcallMethod(
state->_as_doc_str, as_doc_args, 1, NULL);
if (!as_doc) {
return 0;
}
@ -1383,7 +1410,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
return write_unicode(buffer, value);
} else if (PyDateTime_Check(value)) {
long long millis;
PyObject* utcoffset = PyObject_CallMethodObjArgs(value, state->_utcoffset_str , NULL);
PyObject* utcoffset_args[1] = {value};
PyObject* utcoffset = PyObject_VectorcallMethod(
state->_utcoffset_str, utcoffset_args, 1, NULL);
if (utcoffset == NULL)
return 0;
if (utcoffset != Py_None) {
@ -1422,7 +1451,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
if (!(uuid_rep_obj = PyLong_FromLong(options->uuid_rep))) {
return 0;
}
binary_value = PyObject_CallMethodObjArgs(state->Binary, state->_from_uuid_str, value, uuid_rep_obj, NULL);
PyObject* from_uuid_args[3] = {state->Binary, value, uuid_rep_obj};
binary_value = PyObject_VectorcallMethod(
state->_from_uuid_str, from_uuid_args, 3, NULL);
Py_DECREF(uuid_rep_obj);
if (binary_value == NULL) {
@ -1452,7 +1483,8 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
if (converter != NULL) {
/* Transform types that have a registered converter.
* A new reference is created upon transformation. */
new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
PyObject* converter_args[1] = {value};
new_value = PyObject_Vectorcall(converter, converter_args, 1, NULL);
if (new_value == NULL) {
return 0;
}
@ -1466,8 +1498,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
/* Try the fallback encoder if one is provided and we have not already
* attempted to use the fallback encoder. */
if (!in_fallback_call && options->type_registry.has_fallback_encoder) {
new_value = PyObject_CallFunctionObjArgs(
options->type_registry.fallback_encoder, value, NULL);
PyObject* fallback_args[1] = {value};
new_value = PyObject_Vectorcall(
options->type_registry.fallback_encoder, fallback_args, 1, NULL);
if (new_value == NULL) {
// propagate any exception raised by the callback
return 0;
@ -1668,7 +1701,8 @@ void handle_invalid_doc_error(PyObject* dict) {
goto cleanup;
}
// Add doc to the error instance as a property.
new_evalue = PyObject_CallFunctionObjArgs(InvalidDocument, new_msg, dict, NULL);
PyObject* exc_args[2] = {new_msg, dict};
new_evalue = PyObject_Vectorcall(InvalidDocument, exc_args, 2, NULL);
Py_DECREF(evalue);
Py_DECREF(etype);
etype = InvalidDocument;
@ -1944,7 +1978,8 @@ static PyObject *_dbref_hook(PyObject* self, PyObject* value) {
PyMapping_DelItem(value, state->_dollar_db_str);
}
ret = PyObject_CallFunctionObjArgs(state->DBRef, ref, id, database, value, NULL);
PyObject* dbref_args[4] = {ref, id, database, value};
ret = PyObject_Vectorcall(state->DBRef, dbref_args, 4, NULL);
Py_DECREF(value);
} else {
ret = value;
@ -2160,7 +2195,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
goto uuiderror;
}
binary_value = PyObject_CallFunction(state->Binary, "(Oi)", data, subtype);
PyObject* subtype_obj = PyLong_FromLong(subtype);
if (!subtype_obj) {
goto uuiderror;
}
PyObject* binary_args[2] = {data, subtype_obj};
binary_value = PyObject_Vectorcall(state->Binary, binary_args, 2, NULL);
Py_DECREF(subtype_obj);
if (binary_value == NULL) {
goto uuiderror;
}
@ -2175,7 +2216,9 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
if (!uuid_rep_obj) {
goto uuiderror;
}
value = PyObject_CallMethodObjArgs(binary_value, state->_as_uuid_str, uuid_rep_obj, NULL);
PyObject* as_uuid_args[2] = {binary_value, uuid_rep_obj};
value = PyObject_VectorcallMethod(
state->_as_uuid_str, as_uuid_args, 2, NULL);
Py_DECREF(uuid_rep_obj);
}
@ -2194,7 +2237,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
Py_DECREF(data);
goto invalid;
}
value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL);
PyObject* binary_args[2] = {data, st};
value = PyObject_Vectorcall(state->Binary, binary_args, 2, NULL);
Py_DECREF(st);
Py_DECREF(data);
if (!value) {
@ -2215,7 +2259,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
if (max < 12) {
goto invalid;
}
value = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12);
PyObject* oid_bytes = PyBytes_FromStringAndSize(buffer + *position, 12);
if (!oid_bytes) {
goto invalid;
}
PyObject* oid_args[1] = {oid_bytes};
value = PyObject_Vectorcall(state->ObjectId, oid_args, 1, NULL);
Py_DECREF(oid_bytes);
*position += 12;
break;
}
@ -2294,7 +2344,14 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
*position += (unsigned)flags_length + 1;
value = PyObject_CallFunction(state->Regex, "Oi", pattern, flags);
PyObject* flags_obj = PyLong_FromLong(flags);
if (!flags_obj) {
Py_DECREF(pattern);
goto invalid;
}
PyObject* regex_args[2] = {pattern, flags_obj};
value = PyObject_Vectorcall(state->Regex, regex_args, 2, NULL);
Py_DECREF(flags_obj);
Py_DECREF(pattern);
break;
}
@ -2327,13 +2384,21 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
*position += coll_length;
id = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12);
PyObject* oid_bytes = PyBytes_FromStringAndSize(buffer + *position, 12);
if (!oid_bytes) {
Py_DECREF(collection);
goto invalid;
}
PyObject* oid_args[1] = {oid_bytes};
id = PyObject_Vectorcall(state->ObjectId, oid_args, 1, NULL);
Py_DECREF(oid_bytes);
if (!id) {
Py_DECREF(collection);
goto invalid;
}
*position += 12;
value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL);
PyObject* dbref_args[2] = {collection, id};
value = PyObject_Vectorcall(state->DBRef, dbref_args, 2, NULL);
Py_DECREF(collection);
Py_DECREF(id);
break;
@ -2363,7 +2428,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
goto invalid;
}
*position += value_length;
value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL);
PyObject* code_args[1] = {code};
value = PyObject_Vectorcall(state->Code, code_args, 1, NULL);
Py_DECREF(code);
break;
}
@ -2429,7 +2495,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
*position += scope_size;
value = PyObject_CallFunctionObjArgs(state->Code, code, scope, NULL);
PyObject* code_scope_args[2] = {code, scope};
value = PyObject_Vectorcall(state->Code, code_scope_args, 2, NULL);
Py_DECREF(code);
Py_DECREF(scope);
break;
@ -2459,7 +2526,19 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
memcpy(&time, buffer + *position + 4, 4);
inc = BSON_UINT32_FROM_LE(inc);
time = BSON_UINT32_FROM_LE(time);
value = PyObject_CallFunction(state->Timestamp, "II", time, inc);
PyObject* time_obj = PyLong_FromUnsignedLong(time);
if (!time_obj) {
goto invalid;
}
PyObject* inc_obj = PyLong_FromUnsignedLong(inc);
if (!inc_obj) {
Py_DECREF(time_obj);
goto invalid;
}
PyObject* ts_args[2] = {time_obj, inc_obj};
value = PyObject_Vectorcall(state->Timestamp, ts_args, 2, NULL);
Py_DECREF(time_obj);
Py_DECREF(inc_obj);
*position += 8;
break;
}
@ -2471,7 +2550,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
memcpy(&ll, buffer + *position, 8);
ll = (int64_t)BSON_UINT64_FROM_LE(ll);
value = PyObject_CallFunction(state->BSONInt64, "L", ll);
PyObject* ll_obj = PyLong_FromLongLong(ll);
if (!ll_obj) {
goto invalid;
}
PyObject* int64_args[1] = {ll_obj};
value = PyObject_Vectorcall(state->BSONInt64, int64_args, 1, NULL);
Py_DECREF(ll_obj);
*position += 8;
break;
}
@ -2484,19 +2569,21 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
if (!_bytes_obj) {
goto invalid;
}
value = PyObject_CallMethodObjArgs(state->Decimal128, state->_from_bid_str, _bytes_obj, NULL);
PyObject* dec128_args[2] = {state->Decimal128, _bytes_obj};
value = PyObject_VectorcallMethod(
state->_from_bid_str, dec128_args, 2, NULL);
Py_DECREF(_bytes_obj);
*position += 16;
break;
}
case 255:
{
value = PyObject_CallFunctionObjArgs(state->MinKey, NULL);
value = PyObject_Vectorcall(state->MinKey, NULL, 0, NULL);
break;
}
case 127:
{
value = PyObject_CallFunctionObjArgs(state->MaxKey, NULL);
value = PyObject_Vectorcall(state->MaxKey, NULL, 0, NULL);
break;
}
default:
@ -2548,7 +2635,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
converter = PyDict_GetItem(options->type_registry.decoder_map, value_type);
if (converter != NULL) {
PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
PyObject* converter_args[1] = {value};
PyObject* new_value = PyObject_Vectorcall(converter, converter_args, 1, NULL);
Py_DECREF(value_type);
Py_DECREF(value);
return new_value;
@ -2716,11 +2804,20 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string,
unsigned max,
const codec_options_t* options) {
unsigned position = 0;
PyObject* dict = PyObject_CallObject(options->document_class, NULL);
PyObject* dict;
int raw_array = 0;
/* Use PyDict_New() directly when document_class is dict.
* This avoids the overhead of PyObject_CallObject() for the common case. */
if (options->is_dict_class) {
dict = PyDict_New();
} else {
dict = PyObject_CallObject(options->document_class, NULL);
}
if (!dict) {
return NULL;
}
int raw_array = 0;
while (position < max) {
PyObject* name = NULL;
PyObject* value = NULL;
@ -2735,7 +2832,24 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string,
position = (unsigned)new_position;
}
PyObject_SetItem(dict, name, value);
/* Use PyDict_SetItem() when document_class is dict.
* PyDict_SetItem() is faster than PyObject_SetItem() because it
* avoids method lookup overhead. */
if (options->is_dict_class) {
if (PyDict_SetItem(dict, name, value) < 0) {
Py_DECREF(name);
Py_DECREF(value);
Py_DECREF(dict);
return NULL;
}
} else {
if (PyObject_SetItem(dict, name, value) < 0) {
Py_DECREF(name);
Py_DECREF(value);
Py_DECREF(dict);
return NULL;
}
}
Py_DECREF(name);
Py_DECREF(value);
}
@ -2747,9 +2861,14 @@ static PyObject* elements_to_dict(PyObject* self, const char* string,
const codec_options_t* options) {
PyObject* result;
if (options->is_raw_bson) {
return PyObject_CallFunction(
options->document_class, "y#O",
string, max, options->options_obj);
PyObject* bson_bytes = PyBytes_FromStringAndSize(string, max);
if (!bson_bytes) {
return NULL;
}
PyObject* raw_args[2] = {bson_bytes, options->options_obj};
result = PyObject_Vectorcall(options->document_class, raw_args, 2, NULL);
Py_DECREF(bson_bytes);
return result;
}
if (Py_EnterRecursiveCall(" while decoding a BSON document"))
return NULL;

View File

@ -72,6 +72,7 @@ typedef struct codec_options_t {
unsigned char datetime_conversion;
PyObject* options_obj;
unsigned char is_raw_bson;
unsigned char is_dict_class;
} codec_options_t;
/* C API functions */

View File

@ -22,6 +22,7 @@ from __future__ import annotations
import copy
import re
import warnings
from collections.abc import Mapping as _Mapping
from typing import (
Any,
@ -99,13 +100,28 @@ class SON(Dict[_Key, _Value]):
yield from self.__keys
def has_key(self, key: _Key) -> bool:
warnings.warn(
"SON.has_key() is deprecated, use the in operator instead",
DeprecationWarning,
stacklevel=2,
)
return key in self.__keys
def iterkeys(self) -> Iterator[_Key]:
warnings.warn(
"SON.iterkeys() is deprecated, use the keys() method instead",
DeprecationWarning,
stacklevel=2,
)
return self.__iter__()
# fourth level uses definitions from lower levels
def itervalues(self) -> Iterator[_Value]:
warnings.warn(
"SON.itervalues() is deprecated, use the values() method instead",
DeprecationWarning,
stacklevel=2,
)
for _, v in self.items():
yield v

View File

@ -1,6 +1,20 @@
Changelog
=========
Changes in Version 4.17.0 (2026/XX/XX)
--------------------------------------
PyMongo 4.17 brings a number of changes including:
- ``has_key``, ``iterkeys`` and ``itervalues`` in :class:`bson.son.SON` have
been deprecated and will be removed in PyMongo 5.0. These methods were
deprecated in favor of the standard dictionary containment operator ``in``
and the ``keys()`` and ``values()`` methods, respectively.
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information.
Changes in Version 4.16.0 (2026/01/07)
--------------------------------------

View File

@ -57,7 +57,9 @@ lint-manual *args="": && resync
[group('test')]
test *args="-v --durations=5 --maxfail=10": && resync
uv run --extra test python -m pytest {{args}}
#!/usr/bin/env bash
set -euo pipefail
uv run ${USE_ACTIVE_VENV:+--active} --extra test python -m pytest {{args}}
[group('test')]
test-numpy *args="": && resync
@ -80,6 +82,25 @@ teardown-tests:
integration-tests:
bash integration_tests/run.sh
[group('test')]
test-coverage *args="":
just setup-tests --cov
just run-tests {{args}}
[group('coverage')]
coverage-report:
uv tool run --with "coverage[toml]" coverage report
[group('coverage')]
coverage-html:
uv tool run --with "coverage[toml]" coverage html
@echo "Coverage report generated in htmlcov/index.html"
[group('coverage')]
coverage-xml:
uv tool run --with "coverage[toml]" coverage xml
@echo "Coverage report generated in coverage.xml"
[group('server')]
run-server *args="":
bash .evergreen/scripts/run-server.sh {{args}}

View File

@ -139,6 +139,7 @@ import collections
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
@ -181,6 +182,28 @@ if TYPE_CHECKING:
_IS_SYNC = False
_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
class _AsyncBoundSessionContext:
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[AsyncClientSession]] = None
self._end_session = end_session
async def __aenter__(self) -> AsyncClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
await self._session.end_session()
class SessionOptions:
"""Options for a new :class:`AsyncClientSession`.
@ -547,6 +570,24 @@ class AsyncClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.
.. code-block:: python
async with client.start_session() as s:
async with s.bind():
# session=s is passed implicitly
await client.db.collection.insert_one({"x": 1})
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called
once the session is no longer used.
.. versionadded:: 4.17
"""
return _AsyncBoundSessionContext(self, end_session)
async def __aenter__(self) -> AsyncClientSession:
return self

View File

@ -65,7 +65,7 @@ from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
@ -1408,7 +1408,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
def _ensure_session(
self, session: Optional[AsyncClientSession] = None
) -> Optional[AsyncClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
if session:
return session
@ -2267,11 +2268,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self, session: Optional[client_session.AsyncClientSession]
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
@ -2301,6 +2305,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if session is not None:
session._process_response(reply)
def _get_bound_session(self) -> Optional[AsyncClientSession]:
bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
)
else:
return None
async def server_info(
self, session: Optional[client_session.AsyncClientSession] = None
) -> dict[str, Any]:

View File

@ -233,13 +233,6 @@ def validate_readable(option: str, value: Any) -> Optional[str]:
return value
def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]:
"""Validate that 'value' is a positive integer or None."""
if value is None:
return value
return validate_positive_integer(option, value)
def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]:
"""Validate that 'value' is a positive integer or 0 or None."""
if value is None:
@ -261,20 +254,6 @@ def validate_string_or_none(option: str, value: Any) -> Optional[str]:
return validate_string(option, value)
def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
"""Validates that 'value' is an integer or string."""
if isinstance(value, int):
return value
elif isinstance(value, str):
try:
return int(value)
except ValueError:
return value
raise TypeError(
f"Wrong type for {option}, value must be an integer or a string, not {type(value)}"
)
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
"""Validates that 'value' is an integer or string."""
if isinstance(value, int):
@ -817,16 +796,6 @@ TIMEOUT_OPTIONS: list[str] = [
"waitqueuetimeoutms",
]
_AUTH_OPTIONS = frozenset(["authmechanismproperties"])
def validate_auth_option(option: str, value: Any) -> tuple[str, Any]:
"""Validate optional authentication parameters."""
lower, value = validate(option, value)
if lower not in _AUTH_OPTIONS:
raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}")
return option, value
def _get_validator(
key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None

View File

@ -139,6 +139,7 @@ import collections
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
@ -180,6 +181,28 @@ if TYPE_CHECKING:
_IS_SYNC = True
_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)
class _BoundSessionContext:
"""Context manager returned by ClientSession.bind() that manages bound state."""
def __init__(self, session: ClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[ClientSession]] = None
self._end_session = end_session
def __enter__(self) -> ClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
self._session.end_session()
class SessionOptions:
"""Options for a new :class:`ClientSession`.
@ -546,6 +569,24 @@ class ClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def bind(self, end_session: bool = True) -> _BoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.
.. code-block:: python
with client.start_session() as s:
with s.bind():
# session=s is passed implicitly
client.db.collection.insert_one({"x": 1})
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called
once the session is no longer used.
.. versionadded:: 4.17
"""
return _BoundSessionContext(self, end_session)
def __enter__(self) -> ClientSession:
return self

View File

@ -108,7 +108,7 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
@ -1406,7 +1406,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
)
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
if session:
return session
@ -2263,11 +2264,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self, session: Optional[client_session.ClientSession]
) -> Generator[Optional[client_session.ClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.ClientSession):
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.ClientSession):
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)
# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
@ -2295,6 +2299,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if session is not None:
session._process_response(reply)
def _get_bound_session(self) -> Optional[ClientSession]:
bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
)
else:
return None
def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
"""Get information about the MongoDB server we're connected to.

View File

@ -189,6 +189,52 @@ class TestSession(AsyncIntegrationTest):
f"{f.__name__} did not return implicit session to pool",
)
# Explicit bound session
for f, args, kw in ops:
async with client.start_session() as s:
async with s.bind():
listener.reset()
s._materialize()
last_use = s._server_session.last_use
start = time.monotonic()
self.assertLessEqual(last_use, start)
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
await f(*args, **kw)
self.assertGreaterEqual(len(listener.started_events), 1)
for event in listener.started_events:
self.assertIn(
"lsid",
event.command,
f"{f.__name__} sent no lsid with {event.command_name}",
)
self.assertEqual(
s.session_id,
event.command["lsid"],
f"{f.__name__} sent wrong lsid with {event.command_name}",
)
self.assertFalse(s.has_ended)
self.assertTrue(s.has_ended)
with self.assertRaisesRegex(InvalidOperation, "ended session"):
async with s.bind():
await f(*args, **kw)
# Test a session cannot be used on another client.
async with self.client2.start_session() as s:
async with s.bind():
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
with self.assertRaisesRegex(
InvalidOperation,
"Only the client that created the bound session can perform operations within its context block",
):
await f(*args, **kw)
async def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
@ -825,6 +871,73 @@ class TestSession(AsyncIntegrationTest):
async with client.start_session() as s:
self.assertRaises(TypeError, lambda: copy.copy(s))
async def test_nested_session_binding(self):
coll = self.client.pymongo_test.test
await coll.insert_one({"x": 1})
session1 = self.client.start_session()
session2 = self.client.start_session()
session1._materialize()
session2._materialize()
try:
self.listener.reset()
# Uses implicit session
await coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
async with session1.bind(end_session=False):
self.listener.reset()
# Uses bound session1
await coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
async with session2.bind(end_session=False):
self.listener.reset()
# Uses bound session2
await coll.find_one()
session2_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session2_lsid, session2.session_id)
self.assertNotEqual(session2_lsid, session1.session_id)
self.listener.reset()
# Use bound session1 again
await coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
self.assertNotEqual(session1_lsid, session2.session_id)
self.listener.reset()
# Uses implicit session
await coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
finally:
await session1.end_session()
await session2.end_session()
async def test_session_binding_end_session(self):
coll = self.client.pymongo_test.test
await coll.insert_one({"x": 1})
async with self.client.start_session().bind() as s1:
await coll.find_one()
self.assertTrue(s1.has_ended)
async with self.client.start_session().bind(end_session=False) as s2:
await coll.find_one()
self.assertFalse(s2.has_ended)
await s2.end_session()
class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener

View File

@ -1464,11 +1464,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
self.assertListEqual(sorted_expected_documents, actual_documents)
async def run_scenario(self, spec, uri=None):
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
await self.kill_all_sessions()
# Handle flaky tests.
flaky_tests = [
("PYTHON-5170", ".*test_discovery_and_monitoring.*"),
@ -1504,6 +1499,15 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
if skip_reason is not None:
raise unittest.SkipTest(f"{skip_reason}")
# Kill all sessions after each test with transactions to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
for op in spec["operations"]:
name = op["name"]
if name == "startTransaction" or name == "withTransaction":
self.addAsyncCleanup(self.kill_all_sessions)
break
# process createEntities
self._uri = uri
self.entity_map = EntityMapUtil(self)

View File

@ -16,43 +16,13 @@
from __future__ import annotations
import asyncio
import functools
import os
import time
import unittest
from collections import abc
from inspect import iscoroutinefunction
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.asynchronous import async_client_context
from test.asynchronous.helpers import ConcurrentRunner
from test.utils_shared import (
CMAPListener,
CompareType,
EventListener,
OvertCommandListener,
ScenarioDict,
ServerAndTopologyEventListener,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
prepare_spec_arguments,
)
from typing import List
from test.utils_shared import ScenarioDict
from bson import ObjectId, decode, encode, json_util
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from gridfs.asynchronous.grid_file import AsyncGridFSBucket
from pymongo.asynchronous import client_session
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from bson import json_util
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
@ -219,597 +189,3 @@ class AsyncSpecTestCreator:
self._create_tests()
else:
asyncio.run(self._create_tests())
class AsyncSpecRunner(AsyncIntegrationTest):
mongos_clients: List
knobs: client_knobs
listener: EventListener
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
async def asyncTearDown(self) -> None:
self.knobs.disable()
async def set_fail_point(self, command_args):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
await self.configure_fail_point(client, command_args)
async def targeted_fail_point(self, session, fail_point):
"""Run the targetedFailPoint test operation.
Enable the fail point on the session's pinned mongos.
"""
clients = {c.address: c for c in self.mongos_clients}
client = clients[session._pinned_address]
await self.configure_fail_point(client, fail_point)
self.addAsyncCleanup(self.set_fail_point, {"mode": "off"})
def assert_session_pinned(self, session):
"""Run the assertSessionPinned test operation.
Assert that the given session is pinned.
"""
self.assertIsNotNone(session._transaction.pinned_address)
def assert_session_unpinned(self, session):
"""Run the assertSessionUnpinned test operation.
Assert that the given session is not pinned.
"""
self.assertIsNone(session._pinned_address)
self.assertIsNone(session._transaction.pinned_address)
async def assert_collection_exists(self, database, collection):
"""Run the assertCollectionExists test operation."""
db = self.client[database]
self.assertIn(collection, await db.list_collection_names())
async def assert_collection_not_exists(self, database, collection):
"""Run the assertCollectionNotExists test operation."""
db = self.client[database]
self.assertNotIn(collection, await db.list_collection_names())
async def assert_index_exists(self, database, collection, index):
"""Run the assertIndexExists test operation."""
coll = self.client[database][collection]
self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()])
async def assert_index_not_exists(self, database, collection, index):
"""Run the assertIndexNotExists test operation."""
coll = self.client[database][collection]
self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()])
async def wait(self, ms):
"""Run the "wait" test operation."""
await asyncio.sleep(ms / 1000.0)
def assertErrorLabelsContain(self, exc, expected_labels):
labels = [l for l in expected_labels if exc.has_error_label(l)]
self.assertEqual(labels, expected_labels)
def assertErrorLabelsOmit(self, exc, omit_labels):
for label in omit_labels:
self.assertFalse(
exc.has_error_label(label), msg=f"error labels should not contain {label}"
)
async def kill_all_sessions(self):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
try:
await client.admin.command("killAllSessions", [])
except (OperationFailure, AutoReconnect):
# "operation was interrupted" by killing the command's
# own session.
# On 8.0+ killAllSessions sometimes returns a network error.
pass
def check_command_result(self, expected_result, result):
# Only compare the keys in the expected result.
filtered_result = {}
for key in expected_result:
try:
filtered_result[key] = result[key]
except KeyError:
pass
self.assertEqual(filtered_result, expected_result)
# TODO: factor the following function with test_crud.py.
def check_result(self, expected_result, result):
if isinstance(result, _WriteResult):
for res in expected_result:
prop = camel_to_snake(res)
# SPEC-869: Only BulkWriteResult has upserted_count.
if prop == "upserted_count" and not isinstance(result, BulkWriteResult):
if result.upserted_id is not None:
upserted_count = 1
else:
upserted_count = 0
self.assertEqual(upserted_count, expected_result[res], prop)
elif prop == "inserted_ids":
# BulkWriteResult does not have inserted_ids.
if isinstance(result, BulkWriteResult):
self.assertEqual(len(expected_result[res]), result.inserted_count)
else:
# InsertManyResult may be compared to [id1] from the
# crud spec or {"0": id1} from the retryable write spec.
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, prop)
elif prop == "upserted_ids":
# Convert indexes from strings to integers.
ids = expected_result[res]
expected_ids = {}
for str_index in ids:
expected_ids[int(str_index)] = ids[str_index]
self.assertEqual(expected_ids, result.upserted_ids, prop)
else:
self.assertEqual(getattr(result, prop), expected_result[res], prop)
return True
else:
def _helper(expected_result, result):
if isinstance(expected_result, abc.Mapping):
for i in expected_result.keys():
self.assertEqual(expected_result[i], result[i])
elif isinstance(expected_result, list):
for i, k in zip(expected_result, result):
_helper(i, k)
else:
self.assertEqual(expected_result, result)
_helper(expected_result, result)
return None
def get_object_name(self, op):
"""Allow subclasses to override handling of 'object'
Transaction spec says 'object' is required.
"""
return op["object"]
@staticmethod
def parse_options(opts):
return parse_spec_options(opts)
async def run_operation(self, sessions, collection, operation):
original_collection = collection
name = camel_to_snake(operation["name"])
if name == "run_command":
name = "command"
elif name == "download_by_name":
name = "open_download_stream_by_name"
elif name == "download":
name = "open_download_stream"
elif name == "map_reduce":
self.skipTest("PyMongo does not support mapReduce")
elif name == "count":
self.skipTest("PyMongo does not support count")
database = collection.database
collection = database.get_collection(collection.name)
if "collectionOptions" in operation:
collection = collection.with_options(
**self.parse_options(operation["collectionOptions"])
)
object_name = self.get_object_name(operation)
if object_name == "gridfsbucket":
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = AsyncGridFSBucket(database, bucket_name=collection.name)
else:
objects = {
"client": database.client,
"database": database,
"collection": collection,
"testRunner": self,
}
objects.update(sessions)
obj = objects[object_name]
# Combine arguments with options and handle special cases.
arguments = operation.get("arguments", {})
arguments.update(arguments.pop("options", {}))
self.parse_options(arguments)
cmd = getattr(obj, name)
with_txn_callback = functools.partial(
self.run_operations, sessions, original_collection, in_with_transaction=True
)
prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)
if name == "run_on_thread":
args = {"sessions": sessions, "collection": collection}
args.update(arguments)
arguments = args
if not _IS_SYNC and iscoroutinefunction(cmd):
result = await cmd(**dict(arguments))
else:
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addAsyncCleanup(result.close)
if name == "aggregate":
if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
# Read from the primary to ensure causal consistency.
out = collection.database.get_collection(
arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
)
return out.find()
if "download" in name:
result = Binary(result.read())
if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor):
return await result.to_list()
return result
def allowable_errors(self, op):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
async def _run_op(self, sessions, collection, op, in_with_transaction):
expected_result = op.get("result")
if expect_error(op):
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
await self.run_operation(sessions, collection, op.copy())
exc = context.exception
if expect_error_message(expected_result):
if isinstance(exc, BulkWriteError):
errmsg = str(exc.details).lower()
else:
errmsg = str(exc).lower()
self.assertIn(expected_result["errorContains"].lower(), errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
if expect_timeout_error(expected_result):
self.assertIsInstance(exc, PyMongoError)
if not exc.timeout:
# Re-raise the exception for better diagnostics.
raise exc
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = await self.run_operation(sessions, collection, op.copy())
if "result" in op:
if op["name"] == "runCommand":
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
async def run_operations(self, sessions, collection, ops, in_with_transaction=False):
for op in ops:
await self._run_op(sessions, collection, op, in_with_transaction)
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
events = listener.started_events
if not len(test["expectations"]):
return
# Give a nicer message when there are missing or extra events
cmds = decode_raw([event.command for event in events])
self.assertEqual(len(events), len(test["expectations"]), cmds)
for i, expectation in enumerate(test["expectations"]):
event_type = next(iter(expectation))
event = events[i]
# The tests substitute 42 for any number other than 0.
if event.command_name == "getMore" and event.command["getMore"]:
event.command["getMore"] = Int64(42)
elif event.command_name == "killCursors":
event.command["cursors"] = [Int64(42)]
elif event.command_name == "update":
# TODO: remove this once PYTHON-1744 is done.
# Add upsert and multi fields back into expectations.
updates = expectation[event_type]["command"]["updates"]
for update in updates:
update.setdefault("upsert", False)
update.setdefault("multi", False)
# Replace afterClusterTime: 42 with actual afterClusterTime.
expected_cmd = expectation[event_type]["command"]
expected_read_concern = expected_cmd.get("readConcern")
if expected_read_concern is not None:
time = expected_read_concern.get("afterClusterTime")
if time == 42:
actual_time = event.command.get("readConcern", {}).get("afterClusterTime")
if actual_time is not None:
expected_read_concern["afterClusterTime"] = actual_time
recovery_token = expected_cmd.get("recoveryToken")
if recovery_token == 42:
expected_cmd["recoveryToken"] = CompareType(dict)
# Replace lsid with a name like "session0" to match test.
if "lsid" in event.command:
for name, lsid in session_ids.items():
if event.command["lsid"] == lsid:
event.command["lsid"] = name
break
for attr, expected in expectation[event_type].items():
actual = getattr(event, attr)
expected = wrap_types(expected)
if isinstance(expected, dict):
for key, val in expected.items():
if val is None:
if key in actual:
self.fail(f"Unexpected key [{key}] in {actual!r}")
elif key not in actual:
self.fail(f"Expected key [{key}] in {actual!r}")
else:
self.assertEqual(
val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
)
else:
self.assertEqual(actual, expected)
def maybe_skip_scenario(self, test):
if test.get("skipReason"):
self.skipTest(test.get("skipReason"))
def get_scenario_db_name(self, scenario_def):
"""Allow subclasses to override a test's database name."""
return scenario_def["database_name"]
def get_scenario_coll_name(self, scenario_def):
"""Allow subclasses to override a test's collection name."""
return scenario_def["collection_name"]
def get_outcome_coll_name(self, outcome, collection):
"""Allow subclasses to override outcome collection."""
return collection.name
async def run_test_ops(self, sessions, collection, test):
"""Added to allow retryable writes spec to override a test's
operation.
"""
await self.run_operations(sessions, collection, test["operations"])
def parse_client_options(self, opts):
"""Allow encryption spec to override a clientOptions parsing."""
return opts
async def setup_scenario(self, scenario_def):
"""Allow specs to override a test's setup."""
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
documents = scenario_def["data"]
# Setup the collection with as few majority writes as possible.
db = async_client_context.client.get_database(db_name)
coll_exists = bool(await db.list_collection_names(filter={"name": coll_name}))
if coll_exists:
await db[coll_name].delete_many({})
# Only use majority wc only on the final write.
wc = WriteConcern(w="majority")
if documents:
db.get_collection(coll_name, write_concern=wc).insert_many(documents)
elif not coll_exists:
# Ensure collection exists.
await db.create_collection(coll_name, write_concern=wc)
async def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
await self.kill_all_sessions()
self.addAsyncCleanup(self.kill_all_sessions)
await self.setup_scenario(scenario_def)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
await c[database_name][collection_name].distinct("x")
# Configure the fail point before creating the client.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(fp)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
listener = OvertCommandListener()
pool_listener = CMAPListener()
server_listener = ServerAndTopologyEventListener()
# Create a new client, to avoid interference from pooled sessions.
client_options = self.parse_client_options(test["clientOptions"])
use_multi_mongos = test["useMultipleMongoses"]
host = None
if use_multi_mongos:
if async_client_context.load_balancer:
host = async_client_context.MULTI_MONGOS_LB_URI
elif async_client_context.is_mongos:
host = async_client_context.mongos_seeds()
client = await self.async_rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Create session0 and session1.
sessions = {}
session_ids = {}
for i in range(2):
# Don't attempt to create sessions if they are not supported by
# the running server version.
if not async_client_context.sessions_enabled:
break
session_name = "session%d" % i
opts = camel_to_snake_args(test["sessionOptions"][session_name])
if "default_transaction_options" in opts:
txn_opts = self.parse_options(opts["default_transaction_options"])
txn_opts = client_session.TransactionOptions(**txn_opts)
opts["default_transaction_options"] = txn_opts
s = client.start_session(**dict(opts))
sessions[session_name] = s
# Store lsid so we can access it after end_session, in check_events.
session_ids[session_name] = s.session_id
self.addAsyncCleanup(end_sessions, sessions)
collection = client[database_name][collection_name]
await self.run_test_ops(sessions, collection, test)
await end_sessions(sessions)
self.check_events(test, listener, session_ids)
# Disable fail points.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(
{"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
# Assert final state is expected.
outcome = test["outcome"]
expected_c = outcome.get("collection")
if expected_c is not None:
outcome_coll_name = self.get_outcome_coll_name(outcome, collection)
# Read from the primary with local read concern to ensure causal
# consistency.
outcome_coll = async_client_context.client[collection.database.name].get_collection(
outcome_coll_name,
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.
self.assertEqual(wrap_types(expected_c["data"]), actual_data)
def expect_any_error(op):
if isinstance(op, dict):
return op.get("error")
return False
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return isinstance(expected_result["errorContains"], str)
return False
def expect_error_code(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorCodeName"]
return False
def expect_error_labels_contain(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsContain"]
return False
def expect_error_labels_omit(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsOmit"]
return False
def expect_timeout_error(expected_result):
if isinstance(expected_result, dict):
return expected_result["isTimeoutError"]
return False
def expect_error(op):
expected_result = op.get("result")
return (
expect_any_error(op)
or expect_error_message(expected_result)
or expect_error_code(expected_result)
or expect_error_labels_contain(expected_result)
or expect_error_labels_omit(expected_result)
or expect_timeout_error(expected_result)
)
async def end_sessions(sessions):
for s in sessions.values():
# Aborts the transaction if it's open.
await s.end_session()
def decode_raw(val):
"""Decode RawBSONDocuments in the given container."""
if isinstance(val, (list, abc.Mapping)):
return decode(encode({"v": val}))["v"]
return val
TYPES = {
"binData": Binary,
"long": Int64,
"int": int,
"string": str,
"objectId": ObjectId,
"object": dict,
"array": list,
}
def wrap_types(val):
"""Support $$type assertion in command results."""
if isinstance(val, list):
return [wrap_types(v) for v in val]
if isinstance(val, abc.Mapping):
typ = val.get("$$type")
if typ:
if isinstance(typ, str):
types = TYPES[typ]
else:
types = tuple(TYPES[t] for t in typ)
return CompareType(types)
d = {}
for key in val:
d[key] = wrap_types(val[key])
return d
return val

View File

@ -27,7 +27,8 @@
"awaitMinPoolSizeMS": 10000,
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
"commandStartedEvent",
"commandFailedEvent"
]
}
},
@ -188,6 +189,11 @@
}
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction",
@ -206,6 +212,105 @@
]
}
]
},
{
"description": "withTransaction surfaces a timeout after exhausting transient transaction retries, retaining the last transient error as the timeout cause.",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"insert"
],
"blockConnection": true,
"blockTimeMS": 25,
"errorCode": 24,
"errorLabels": [
"TransientTransactionError"
]
}
}
}
},
{
"name": "withTransaction",
"object": "session",
"arguments": {
"callback": [
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 1
},
"session": "session"
},
"expectError": {
"isError": true
}
}
]
},
"expectError": {
"isTimeoutError": true
}
}
],
"expectEvents": [
{
"client": "client",
"ignoreExtraEvents": true,
"events": [
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
},
{
"commandFailedEvent": {
"commandName": "abortTransaction"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
},
{
"commandFailedEvent": {
"commandName": "abortTransaction"
}
}
]
}
]
}
]
}

View File

@ -85,7 +85,7 @@
}
},
{
"description": "Mark server unknown on network timeout application error (beforeHandshakeCompletes)",
"description": "Ignore network timeout application error (beforeHandshakeCompletes)",
"applicationErrors": [
{
"address": "a:27017",

View File

@ -0,0 +1,167 @@
{
"description": "Static setVersion (DSC) is compatible with both pre and post DRIVERS-2412",
"uri": "mongodb://a/?replicaSet=rs",
"phases": [
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000005"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
],
[
"b:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": false,
"secondary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000005"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000005"
}
}
},
{
"responses": [
[
"b:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000006"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "Unknown",
"setName": null,
"setVersion": null,
"electionId": null
},
"b:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000006"
}
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000006"
}
}
},
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000005"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "Unknown",
"setName": null,
"setVersion": null,
"electionId": null
},
"b:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000006"
}
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000006"
}
}
}
]
}

View File

@ -0,0 +1,227 @@
{
"description": "Member list is updated when setVersion and electionId remain the same",
"uri": "mongodb://a/?replicaSet=rs",
"phases": [
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
],
[
"b:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": false,
"secondary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000001"
}
}
},
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017",
"c:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
},
"c:27017": {
"type": "Unknown",
"setName": null,
"setVersion": null,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000001"
}
}
},
{
"responses": [
[
"c:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": false,
"secondary": true,
"hosts": [
"a:27017",
"b:27017",
"c:27017"
],
"setName": "rs",
"setVersion": 1,
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
},
"c:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000001"
}
}
},
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000001"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000001"
}
}
}
]
}

View File

@ -0,0 +1,167 @@
{
"description": "DSC to ASC reverse migration - ASC primary with higher setVersion is accepted",
"uri": "mongodb://a/?replicaSet=rs",
"phases": [
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000005"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
],
[
"b:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": false,
"secondary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000005"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 1,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1,
"maxElectionId": {
"$oid": "000000000000000000000005"
}
}
},
{
"responses": [
[
"b:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1000,
"electionId": {
"$oid": "000000000000000000000006"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "Unknown",
"setName": null,
"setVersion": null,
"electionId": null
},
"b:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1000,
"electionId": {
"$oid": "000000000000000000000006"
}
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1000,
"maxElectionId": {
"$oid": "000000000000000000000006"
}
}
},
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 1,
"electionId": {
"$oid": "000000000000000000000005"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "Unknown",
"setName": null,
"setVersion": null,
"electionId": null
},
"b:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 1000,
"electionId": {
"$oid": "000000000000000000000006"
}
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 1000,
"maxElectionId": {
"$oid": "000000000000000000000006"
}
}
}
]
}

View File

@ -0,0 +1,119 @@
{
"description": "ASC to DSC forward migration - DSC uses setVersionASC + 1 to prevent false stale detection",
"uri": "mongodb://a/?replicaSet=rs",
"phases": [
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 10,
"electionId": {
"$oid": "000000000000000000000005"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
],
[
"b:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": false,
"secondary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 10,
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 10,
"electionId": {
"$oid": "000000000000000000000005"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 10,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 10,
"maxElectionId": {
"$oid": "000000000000000000000005"
}
}
},
{
"responses": [
[
"a:27017",
{
"ok": 1,
"helloOk": true,
"isWritablePrimary": true,
"hosts": [
"a:27017",
"b:27017"
],
"setName": "rs",
"setVersion": 11,
"electionId": {
"$oid": "000000000000000000000006"
},
"minWireVersion": 0,
"maxWireVersion": 17
}
]
],
"outcome": {
"servers": {
"a:27017": {
"type": "RSPrimary",
"setName": "rs",
"setVersion": 11,
"electionId": {
"$oid": "000000000000000000000006"
}
},
"b:27017": {
"type": "RSSecondary",
"setName": "rs",
"setVersion": 10,
"electionId": null
}
},
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs",
"maxSetVersion": 11,
"maxElectionId": {
"$oid": "000000000000000000000006"
}
}
}
]
}

View File

@ -0,0 +1,62 @@
{
"topology_description": {
"type": "ReplicaSetNoPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "Nearest",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,62 @@
{
"topology_description": {
"type": "ReplicaSetNoPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "PrimaryPreferred",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,62 @@
{
"topology_description": {
"type": "ReplicaSetNoPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "Secondary",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,62 @@
{
"topology_description": {
"type": "ReplicaSetNoPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "SecondaryPreferred",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,70 @@
{
"topology_description": {
"type": "ReplicaSetWithPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
},
{
"address": "a:27017",
"avg_rtt_ms": 26,
"type": "RSPrimary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "Nearest",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,70 @@
{
"topology_description": {
"type": "ReplicaSetWithPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
},
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "nyc"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "PrimaryPreferred",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,70 @@
{
"topology_description": {
"type": "ReplicaSetWithPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
},
{
"address": "a:27017",
"avg_rtt_ms": 26,
"type": "RSPrimary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "Secondary",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
]
}

View File

@ -0,0 +1,70 @@
{
"topology_description": {
"type": "ReplicaSetWithPrimary",
"servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
},
{
"address": "c:27017",
"avg_rtt_ms": 100,
"type": "RSSecondary",
"tags": {
"data_center": "tokyo"
}
},
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "tokyo"
}
}
]
},
"operation": "read",
"read_preference": {
"mode": "SecondaryPreferred",
"tag_sets": [
{
"data_center": "nyc"
}
]
},
"deprioritized_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary",
"tags": {
"data_center": "nyc"
}
}
],
"suitable_servers": [
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "tokyo"
}
}
],
"in_latency_window": [
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary",
"tags": {
"data_center": "tokyo"
}
}
]
}

View File

@ -0,0 +1,41 @@
{
"topology_description": {
"type": "ReplicaSetWithPrimary",
"servers": [
{
"address": "a:27017",
"avg_rtt_ms": 5,
"type": "RSPrimary"
},
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary"
}
]
},
"operation": "read",
"read_preference": {
"mode": "SecondaryPreferred",
"tag_sets": [
{
"data_center": "nyc"
},
{}
]
},
"suitable_servers": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary"
}
],
"in_latency_window": [
{
"address": "b:27017",
"avg_rtt_ms": 5,
"type": "RSSecondary"
}
]
}

View File

@ -189,6 +189,52 @@ class TestSession(IntegrationTest):
f"{f.__name__} did not return implicit session to pool",
)
# Explicit bound session
for f, args, kw in ops:
with client.start_session() as s:
with s.bind():
listener.reset()
s._materialize()
last_use = s._server_session.last_use
start = time.monotonic()
self.assertLessEqual(last_use, start)
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
f(*args, **kw)
self.assertGreaterEqual(len(listener.started_events), 1)
for event in listener.started_events:
self.assertIn(
"lsid",
event.command,
f"{f.__name__} sent no lsid with {event.command_name}",
)
self.assertEqual(
s.session_id,
event.command["lsid"],
f"{f.__name__} sent wrong lsid with {event.command_name}",
)
self.assertFalse(s.has_ended)
self.assertTrue(s.has_ended)
with self.assertRaisesRegex(InvalidOperation, "ended session"):
with s.bind():
f(*args, **kw)
# Test a session cannot be used on another client.
with self.client2.start_session() as s:
with s.bind():
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
with self.assertRaisesRegex(
InvalidOperation,
"Only the client that created the bound session can perform operations within its context block",
):
f(*args, **kw)
def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
@ -825,6 +871,73 @@ class TestSession(IntegrationTest):
with client.start_session() as s:
self.assertRaises(TypeError, lambda: copy.copy(s))
def test_nested_session_binding(self):
coll = self.client.pymongo_test.test
coll.insert_one({"x": 1})
session1 = self.client.start_session()
session2 = self.client.start_session()
session1._materialize()
session2._materialize()
try:
self.listener.reset()
# Uses implicit session
coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
with session1.bind(end_session=False):
self.listener.reset()
# Uses bound session1
coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
with session2.bind(end_session=False):
self.listener.reset()
# Uses bound session2
coll.find_one()
session2_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session2_lsid, session2.session_id)
self.assertNotEqual(session2_lsid, session1.session_id)
self.listener.reset()
# Use bound session1 again
coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
self.assertNotEqual(session1_lsid, session2.session_id)
self.listener.reset()
# Uses implicit session
coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
finally:
session1.end_session()
session2.end_session()
def test_session_binding_end_session(self):
coll = self.client.pymongo_test.test
coll.insert_one({"x": 1})
with self.client.start_session().bind() as s1:
coll.find_one()
self.assertTrue(s1.has_ended)
with self.client.start_session().bind(end_session=False) as s2:
coll.find_one()
self.assertFalse(s2.has_ended)
s2.end_session()
class TestCausalConsistency(UnitTest):
listener: SessionTestListener

View File

@ -145,13 +145,11 @@ class TestSON(unittest.TestCase):
self.assertEqual(ele * 100, test_son[ele])
def test_contains_has(self):
"""has_key and __contains__"""
"""Test key membership via 'in' and __contains__."""
test_son = SON([(1, 100), (2, 200), (3, 300)])
self.assertIn(1, test_son)
self.assertIn(2, test_son, "in failed")
self.assertNotIn(22, test_son, "in succeeded when it shouldn't")
self.assertTrue(test_son.has_key(2), "has_key failed")
self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't")
def test_clears(self):
"""Test clear()"""

View File

@ -1451,11 +1451,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
self.assertListEqual(sorted_expected_documents, actual_documents)
def run_scenario(self, spec, uri=None):
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
self.kill_all_sessions()
# Handle flaky tests.
flaky_tests = [
("PYTHON-5170", ".*test_discovery_and_monitoring.*"),
@ -1491,6 +1486,15 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
if skip_reason is not None:
raise unittest.SkipTest(f"{skip_reason}")
# Kill all sessions after each test with transactions to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
for op in spec["operations"]:
name = op["name"]
if name == "startTransaction" or name == "withTransaction":
self.addCleanup(self.kill_all_sessions)
break
# process createEntities
self._uri = uri
self.entity_map = EntityMapUtil(self)

View File

@ -16,43 +16,13 @@
from __future__ import annotations
import asyncio
import functools
import os
import time
import unittest
from collections import abc
from inspect import iscoroutinefunction
from test import IntegrationTest, client_context, client_knobs
from test import client_context
from test.helpers import ConcurrentRunner
from test.utils_shared import (
CMAPListener,
CompareType,
EventListener,
OvertCommandListener,
ScenarioDict,
ServerAndTopologyEventListener,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
prepare_spec_arguments,
)
from typing import List
from test.utils_shared import ScenarioDict
from bson import ObjectId, decode, encode, json_util
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from gridfs.synchronous.grid_file import GridFSBucket
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from bson import json_util
from pymongo.lock import _cond_wait, _create_condition, _create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
from pymongo.synchronous import client_session
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.cursor import Cursor
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
@ -219,595 +189,3 @@ class SpecTestCreator:
self._create_tests()
else:
asyncio.run(self._create_tests())
class SpecRunner(IntegrationTest):
mongos_clients: List
knobs: client_knobs
listener: EventListener
def setUp(self) -> None:
super().setUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
def tearDown(self) -> None:
self.knobs.disable()
def set_fail_point(self, command_args):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
self.configure_fail_point(client, command_args)
def targeted_fail_point(self, session, fail_point):
"""Run the targetedFailPoint test operation.
Enable the fail point on the session's pinned mongos.
"""
clients = {c.address: c for c in self.mongos_clients}
client = clients[session._pinned_address]
self.configure_fail_point(client, fail_point)
self.addCleanup(self.set_fail_point, {"mode": "off"})
def assert_session_pinned(self, session):
"""Run the assertSessionPinned test operation.
Assert that the given session is pinned.
"""
self.assertIsNotNone(session._transaction.pinned_address)
def assert_session_unpinned(self, session):
"""Run the assertSessionUnpinned test operation.
Assert that the given session is not pinned.
"""
self.assertIsNone(session._pinned_address)
self.assertIsNone(session._transaction.pinned_address)
def assert_collection_exists(self, database, collection):
"""Run the assertCollectionExists test operation."""
db = self.client[database]
self.assertIn(collection, db.list_collection_names())
def assert_collection_not_exists(self, database, collection):
"""Run the assertCollectionNotExists test operation."""
db = self.client[database]
self.assertNotIn(collection, db.list_collection_names())
def assert_index_exists(self, database, collection, index):
"""Run the assertIndexExists test operation."""
coll = self.client[database][collection]
self.assertIn(index, [doc["name"] for doc in coll.list_indexes()])
def assert_index_not_exists(self, database, collection, index):
"""Run the assertIndexNotExists test operation."""
coll = self.client[database][collection]
self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()])
def wait(self, ms):
"""Run the "wait" test operation."""
time.sleep(ms / 1000.0)
def assertErrorLabelsContain(self, exc, expected_labels):
labels = [l for l in expected_labels if exc.has_error_label(l)]
self.assertEqual(labels, expected_labels)
def assertErrorLabelsOmit(self, exc, omit_labels):
for label in omit_labels:
self.assertFalse(
exc.has_error_label(label), msg=f"error labels should not contain {label}"
)
def kill_all_sessions(self):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
try:
client.admin.command("killAllSessions", [])
except (OperationFailure, AutoReconnect):
# "operation was interrupted" by killing the command's
# own session.
# On 8.0+ killAllSessions sometimes returns a network error.
pass
def check_command_result(self, expected_result, result):
# Only compare the keys in the expected result.
filtered_result = {}
for key in expected_result:
try:
filtered_result[key] = result[key]
except KeyError:
pass
self.assertEqual(filtered_result, expected_result)
# TODO: factor the following function with test_crud.py.
def check_result(self, expected_result, result):
if isinstance(result, _WriteResult):
for res in expected_result:
prop = camel_to_snake(res)
# SPEC-869: Only BulkWriteResult has upserted_count.
if prop == "upserted_count" and not isinstance(result, BulkWriteResult):
if result.upserted_id is not None:
upserted_count = 1
else:
upserted_count = 0
self.assertEqual(upserted_count, expected_result[res], prop)
elif prop == "inserted_ids":
# BulkWriteResult does not have inserted_ids.
if isinstance(result, BulkWriteResult):
self.assertEqual(len(expected_result[res]), result.inserted_count)
else:
# InsertManyResult may be compared to [id1] from the
# crud spec or {"0": id1} from the retryable write spec.
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, prop)
elif prop == "upserted_ids":
# Convert indexes from strings to integers.
ids = expected_result[res]
expected_ids = {}
for str_index in ids:
expected_ids[int(str_index)] = ids[str_index]
self.assertEqual(expected_ids, result.upserted_ids, prop)
else:
self.assertEqual(getattr(result, prop), expected_result[res], prop)
return True
else:
def _helper(expected_result, result):
if isinstance(expected_result, abc.Mapping):
for i in expected_result.keys():
self.assertEqual(expected_result[i], result[i])
elif isinstance(expected_result, list):
for i, k in zip(expected_result, result):
_helper(i, k)
else:
self.assertEqual(expected_result, result)
_helper(expected_result, result)
return None
def get_object_name(self, op):
"""Allow subclasses to override handling of 'object'
Transaction spec says 'object' is required.
"""
return op["object"]
@staticmethod
def parse_options(opts):
return parse_spec_options(opts)
def run_operation(self, sessions, collection, operation):
original_collection = collection
name = camel_to_snake(operation["name"])
if name == "run_command":
name = "command"
elif name == "download_by_name":
name = "open_download_stream_by_name"
elif name == "download":
name = "open_download_stream"
elif name == "map_reduce":
self.skipTest("PyMongo does not support mapReduce")
elif name == "count":
self.skipTest("PyMongo does not support count")
database = collection.database
collection = database.get_collection(collection.name)
if "collectionOptions" in operation:
collection = collection.with_options(
**self.parse_options(operation["collectionOptions"])
)
object_name = self.get_object_name(operation)
if object_name == "gridfsbucket":
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = GridFSBucket(database, bucket_name=collection.name)
else:
objects = {
"client": database.client,
"database": database,
"collection": collection,
"testRunner": self,
}
objects.update(sessions)
obj = objects[object_name]
# Combine arguments with options and handle special cases.
arguments = operation.get("arguments", {})
arguments.update(arguments.pop("options", {}))
self.parse_options(arguments)
cmd = getattr(obj, name)
with_txn_callback = functools.partial(
self.run_operations, sessions, original_collection, in_with_transaction=True
)
prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)
if name == "run_on_thread":
args = {"sessions": sessions, "collection": collection}
args.update(arguments)
arguments = args
if not _IS_SYNC and iscoroutinefunction(cmd):
result = cmd(**dict(arguments))
else:
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addCleanup(result.close)
if name == "aggregate":
if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
# Read from the primary to ensure causal consistency.
out = collection.database.get_collection(
arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
)
return out.find()
if "download" in name:
result = Binary(result.read())
if isinstance(result, Cursor) or isinstance(result, CommandCursor):
return result.to_list()
return result
def allowable_errors(self, op):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
def _run_op(self, sessions, collection, op, in_with_transaction):
expected_result = op.get("result")
if expect_error(op):
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
self.run_operation(sessions, collection, op.copy())
exc = context.exception
if expect_error_message(expected_result):
if isinstance(exc, BulkWriteError):
errmsg = str(exc.details).lower()
else:
errmsg = str(exc).lower()
self.assertIn(expected_result["errorContains"].lower(), errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
if expect_timeout_error(expected_result):
self.assertIsInstance(exc, PyMongoError)
if not exc.timeout:
# Re-raise the exception for better diagnostics.
raise exc
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = self.run_operation(sessions, collection, op.copy())
if "result" in op:
if op["name"] == "runCommand":
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
def run_operations(self, sessions, collection, ops, in_with_transaction=False):
for op in ops:
self._run_op(sessions, collection, op, in_with_transaction)
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
events = listener.started_events
if not len(test["expectations"]):
return
# Give a nicer message when there are missing or extra events
cmds = decode_raw([event.command for event in events])
self.assertEqual(len(events), len(test["expectations"]), cmds)
for i, expectation in enumerate(test["expectations"]):
event_type = next(iter(expectation))
event = events[i]
# The tests substitute 42 for any number other than 0.
if event.command_name == "getMore" and event.command["getMore"]:
event.command["getMore"] = Int64(42)
elif event.command_name == "killCursors":
event.command["cursors"] = [Int64(42)]
elif event.command_name == "update":
# TODO: remove this once PYTHON-1744 is done.
# Add upsert and multi fields back into expectations.
updates = expectation[event_type]["command"]["updates"]
for update in updates:
update.setdefault("upsert", False)
update.setdefault("multi", False)
# Replace afterClusterTime: 42 with actual afterClusterTime.
expected_cmd = expectation[event_type]["command"]
expected_read_concern = expected_cmd.get("readConcern")
if expected_read_concern is not None:
time = expected_read_concern.get("afterClusterTime")
if time == 42:
actual_time = event.command.get("readConcern", {}).get("afterClusterTime")
if actual_time is not None:
expected_read_concern["afterClusterTime"] = actual_time
recovery_token = expected_cmd.get("recoveryToken")
if recovery_token == 42:
expected_cmd["recoveryToken"] = CompareType(dict)
# Replace lsid with a name like "session0" to match test.
if "lsid" in event.command:
for name, lsid in session_ids.items():
if event.command["lsid"] == lsid:
event.command["lsid"] = name
break
for attr, expected in expectation[event_type].items():
actual = getattr(event, attr)
expected = wrap_types(expected)
if isinstance(expected, dict):
for key, val in expected.items():
if val is None:
if key in actual:
self.fail(f"Unexpected key [{key}] in {actual!r}")
elif key not in actual:
self.fail(f"Expected key [{key}] in {actual!r}")
else:
self.assertEqual(
val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
)
else:
self.assertEqual(actual, expected)
def maybe_skip_scenario(self, test):
if test.get("skipReason"):
self.skipTest(test.get("skipReason"))
def get_scenario_db_name(self, scenario_def):
"""Allow subclasses to override a test's database name."""
return scenario_def["database_name"]
def get_scenario_coll_name(self, scenario_def):
"""Allow subclasses to override a test's collection name."""
return scenario_def["collection_name"]
def get_outcome_coll_name(self, outcome, collection):
"""Allow subclasses to override outcome collection."""
return collection.name
def run_test_ops(self, sessions, collection, test):
"""Added to allow retryable writes spec to override a test's
operation.
"""
self.run_operations(sessions, collection, test["operations"])
def parse_client_options(self, opts):
"""Allow encryption spec to override a clientOptions parsing."""
return opts
def setup_scenario(self, scenario_def):
"""Allow specs to override a test's setup."""
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
documents = scenario_def["data"]
# Setup the collection with as few majority writes as possible.
db = client_context.client.get_database(db_name)
coll_exists = bool(db.list_collection_names(filter={"name": coll_name}))
if coll_exists:
db[coll_name].delete_many({})
# Only use majority wc only on the final write.
wc = WriteConcern(w="majority")
if documents:
db.get_collection(coll_name, write_concern=wc).insert_many(documents)
elif not coll_exists:
# Ensure collection exists.
db.create_collection(coll_name, write_concern=wc)
def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
self.kill_all_sessions()
self.addCleanup(self.kill_all_sessions)
self.setup_scenario(scenario_def)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
c[database_name][collection_name].distinct("x")
# Configure the fail point before creating the client.
if "failPoint" in test:
fp = test["failPoint"]
self.set_fail_point(fp)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
listener = OvertCommandListener()
pool_listener = CMAPListener()
server_listener = ServerAndTopologyEventListener()
# Create a new client, to avoid interference from pooled sessions.
client_options = self.parse_client_options(test["clientOptions"])
use_multi_mongos = test["useMultipleMongoses"]
host = None
if use_multi_mongos:
if client_context.load_balancer:
host = client_context.MULTI_MONGOS_LB_URI
elif client_context.is_mongos:
host = client_context.mongos_seeds()
client = self.rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Create session0 and session1.
sessions = {}
session_ids = {}
for i in range(2):
# Don't attempt to create sessions if they are not supported by
# the running server version.
if not client_context.sessions_enabled:
break
session_name = "session%d" % i
opts = camel_to_snake_args(test["sessionOptions"][session_name])
if "default_transaction_options" in opts:
txn_opts = self.parse_options(opts["default_transaction_options"])
txn_opts = client_session.TransactionOptions(**txn_opts)
opts["default_transaction_options"] = txn_opts
s = client.start_session(**dict(opts))
sessions[session_name] = s
# Store lsid so we can access it after end_session, in check_events.
session_ids[session_name] = s.session_id
self.addCleanup(end_sessions, sessions)
collection = client[database_name][collection_name]
self.run_test_ops(sessions, collection, test)
end_sessions(sessions)
self.check_events(test, listener, session_ids)
# Disable fail points.
if "failPoint" in test:
fp = test["failPoint"]
self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"})
# Assert final state is expected.
outcome = test["outcome"]
expected_c = outcome.get("collection")
if expected_c is not None:
outcome_coll_name = self.get_outcome_coll_name(outcome, collection)
# Read from the primary with local read concern to ensure causal
# consistency.
outcome_coll = client_context.client[collection.database.name].get_collection(
outcome_coll_name,
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.
self.assertEqual(wrap_types(expected_c["data"]), actual_data)
def expect_any_error(op):
if isinstance(op, dict):
return op.get("error")
return False
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return isinstance(expected_result["errorContains"], str)
return False
def expect_error_code(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorCodeName"]
return False
def expect_error_labels_contain(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsContain"]
return False
def expect_error_labels_omit(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsOmit"]
return False
def expect_timeout_error(expected_result):
if isinstance(expected_result, dict):
return expected_result["isTimeoutError"]
return False
def expect_error(op):
expected_result = op.get("result")
return (
expect_any_error(op)
or expect_error_message(expected_result)
or expect_error_code(expected_result)
or expect_error_labels_contain(expected_result)
or expect_error_labels_omit(expected_result)
or expect_timeout_error(expected_result)
)
def end_sessions(sessions):
for s in sessions.values():
# Aborts the transaction if it's open.
s.end_session()
def decode_raw(val):
"""Decode RawBSONDocuments in the given container."""
if isinstance(val, (list, abc.Mapping)):
return decode(encode({"v": val}))["v"]
return val
TYPES = {
"binData": Binary,
"long": Int64,
"int": int,
"string": str,
"objectId": ObjectId,
"object": dict,
"array": list,
}
def wrap_types(val):
"""Support $$type assertion in command results."""
if isinstance(val, list):
return [wrap_types(v) for v in val]
if isinstance(val, abc.Mapping):
typ = val.get("$$type")
if typ:
if isinstance(typ, str):
types = TYPES[typ]
else:
types = tuple(TYPES[t] for t in typ)
return CompareType(types)
d = {}
for key in val:
d[key] = wrap_types(val[key])
return d
return val

View File

@ -37,6 +37,7 @@ replacements = {
"AsyncRawBatchCursor": "RawBatchCursor",
"AsyncRawBatchCommandCursor": "RawBatchCommandCursor",
"AsyncClientSession": "ClientSession",
"_AsyncBoundSessionContext": "_BoundSessionContext",
"AsyncChangeStream": "ChangeStream",
"AsyncCollectionChangeStream": "CollectionChangeStream",
"AsyncDatabaseChangeStream": "DatabaseChangeStream",