Merge upstream/backpressure

Resolve merge conflicts by taking upstream versions, which represent the
refined evolution of local prototype backpressure work.

Co-Authored-By: Claude Code <noreply@anthropic.com>
This commit is contained in:
Jib 2026-03-23 11:20:10 -04:00
commit 7acc87ad65
423 changed files with 90669 additions and 9134 deletions

4
.codecov.yml Normal file
View File

@ -0,0 +1,4 @@
# do not notify until at least 100 builds have been uploaded from the CI pipeline
# you can also set after_n_builds on comments independently
comment:
after_n_builds: 100

View File

@ -5,19 +5,12 @@
set -eu
. .evergreen/utils.sh
# Set up the virtual env.
. .evergreen/scripts/setup-dev-env.sh
uv sync --group coverage
source .venv/bin/activate
if [ -z "${PYTHON_BINARY:-}" ]; then
PYTHON_BINARY=$(find_python3)
fi
createvirtualenv "$PYTHON_BINARY" covenv
# Keep in sync with run-tests.sh
# coverage >=5 is needed for relative_files=true.
pip install -q "coverage[toml]>=5,<=7.5"
pip list
ls -la coverage/
python -m coverage combine coverage/coverage.*
python -m coverage html -d htmlcov
coverage combine coverage/coverage.*
coverage html -d htmlcov

View File

@ -38,6 +38,7 @@ post:
# Disabled, causing timeouts
# - func: "upload working dir"
- func: "teardown system"
- func: "upload codecov"
- func: "upload coverage"
- func: "upload mo artifacts"
- func: "upload test results"

View File

@ -101,8 +101,8 @@ functions:
- AUTH
- SSL
- ORCHESTRATION_FILE
- PYTHON_BINARY
- PYTHON_VERSION
- UV_PYTHON
- TOOLCHAIN_VERSION
- STORAGE_ENGINE
- REQUIRE_API_VERSION
- DRIVERS_TOOLS
@ -134,10 +134,10 @@ functions:
- AWS_SECRET_ACCESS_KEY
- AWS_SESSION_TOKEN
- COVERAGE
- PYTHON_BINARY
- UV_PYTHON
- LIBMONGOCRYPT_URL
- MONGODB_URI
- PYTHON_VERSION
- TOOLCHAIN_VERSION
- DISABLE_TEST_COMMANDS
- GREEN_FRAMEWORK
- NO_EXT
@ -151,6 +151,7 @@ functions:
- VERSION
- IS_WIN32
- REQUIRE_FIPS
- TEST_MIN_DEPS
type: test
- command: subprocess.exec
params:
@ -238,6 +239,40 @@ functions:
working_dir: src
type: test
# Test numpy
test numpy:
- command: subprocess.exec
params:
binary: bash
args:
- .evergreen/just.sh
- test-numpy
working_dir: src
include_expansions_in_env:
- TOOLCHAIN_VERSION
- COVERAGE
type: test
# Upload coverage codecov
upload codecov:
- command: subprocess.exec
params:
binary: bash
args:
- .evergreen/scripts/upload-codecov.sh
working_dir: src
include_expansions_in_env:
- CODECOV_TOKEN
- build_variant
- task_name
- github_commit
- github_pr_number
- github_pr_head_branch
- github_author
- requester
- branch_name
type: test
# Upload coverage
upload coverage:
- command: ec2.assume_role

File diff suppressed because it is too large Load Diff

View File

@ -1,27 +1,17 @@
buildvariants:
# Alternative hosts tests
- name: openssl-1.0.2-rhel7-v5.0-python3.9
tasks:
- name: .test-no-toolchain
display_name: OpenSSL 1.0.2 RHEL7 v5.0 Python3.9
run_on:
- rhel79-small
batchtime: 10080
expansions:
VERSION: "5.0"
PYTHON_VERSION: "3.9"
PYTHON_BINARY: /opt/python/3.9/bin/python3
- name: other-hosts-rhel9-fips-latest
tasks:
- name: .test-no-toolchain
display_name: Other hosts RHEL9-FIPS latest
run_on:
- rhel92-fips
batchtime: 10080
batchtime: 1440
expansions:
VERSION: latest
NO_EXT: "1"
REQUIRE_FIPS: "1"
UV_PYTHON: /usr/bin/python3.11
tags: []
- name: other-hosts-rhel8-zseries-latest
tasks:
@ -29,7 +19,7 @@ buildvariants:
display_name: Other hosts RHEL8-zseries latest
run_on:
- rhel8-zseries-small
batchtime: 10080
batchtime: 1440
expansions:
VERSION: latest
NO_EXT: "1"
@ -40,7 +30,7 @@ buildvariants:
display_name: Other hosts RHEL8-POWER8 latest
run_on:
- rhel8-power-small
batchtime: 10080
batchtime: 1440
expansions:
VERSION: latest
NO_EXT: "1"
@ -51,7 +41,7 @@ buildvariants:
display_name: Other hosts RHEL8-arm64 latest
run_on:
- rhel82-arm64-small
batchtime: 10080
batchtime: 1440
expansions:
VERSION: latest
NO_EXT: "1"
@ -62,7 +52,7 @@ buildvariants:
display_name: Other hosts Amazon2023 latest
run_on:
- amazon2023-arm64-latest-large-m8g
batchtime: 10080
batchtime: 1440
expansions:
VERSION: latest
NO_EXT: "1"
@ -79,39 +69,35 @@ buildvariants:
TEST_NAME: atlas_connect
tags: [pr]
# Atlas data lake tests
- name: atlas-data-lake-ubuntu-22
tasks:
- name: .test-no-orchestration
display_name: Atlas Data Lake Ubuntu-22
run_on:
- ubuntu2204-small
expansions:
TEST_NAME: data_lake
tags: [pr]
# Aws auth tests
- name: auth-aws-ubuntu-20
- name: auth-aws-rhel8
tasks:
- name: .auth-aws
display_name: Auth AWS Ubuntu-20
display_name: Auth AWS RHEL8
run_on:
- ubuntu2004-small
- rhel87-small
tags: []
- name: auth-aws-win64
tasks:
- name: .auth-aws !.auth-aws-ecs
- name: .auth-aws
display_name: Auth AWS Win64
run_on:
- windows-64-vsMulti-small
- windows-2022-latest-small
tags: []
- name: auth-aws-macos
tasks:
- name: .auth-aws !.auth-aws-web-identity !.auth-aws-ecs !.auth-aws-ec2
- name: .auth-aws !.auth-aws-web-identity !.auth-aws-ec2
display_name: Auth AWS macOS
run_on:
- macos-14
tags: [pr]
- name: auth-aws-ecs-macos
tasks:
- name: .auth-aws-ecs
display_name: Auth AWS ECS macOS
run_on:
- ubuntu2404-small
tags: [pr]
# Aws lambda tests
- name: faas-lambda
@ -154,6 +140,15 @@ buildvariants:
- rhel87-small
expansions:
COMPRESSOR: zstd
- name: compression-zstd-ubuntu-22
tasks:
- name: .test-standard !.server-4.2 !.server-4.4 !.server-5.0 .python-3.14
- name: .test-standard !.server-4.2 !.server-4.4 !.server-5.0 .python-3.14t
display_name: Compression zstd Ubuntu-22
run_on:
- ubuntu2204-small
expansions:
COMPRESSOR: ztsd
# Coverage report tests
- name: coverage-report
@ -164,17 +159,16 @@ buildvariants:
- rhel87-small
# Disable test commands tests
- name: disable-test-commands-rhel8-python3.9
- name: disable-test-commands-rhel8
tasks:
- name: .test-standard .server-latest
display_name: Disable test commands RHEL8 Python3.9
display_name: Disable test commands RHEL8
run_on:
- rhel87-small
expansions:
AUTH: auth
SSL: ssl
DISABLE_TEST_COMMANDS: "1"
PYTHON_BINARY: /opt/python/3.9/bin/python3
# Doctests tests
- name: doctests-rhel8
@ -193,7 +187,7 @@ buildvariants:
display_name: Encryption RHEL8
run_on:
- rhel87-small
batchtime: 10080
batchtime: 1440
expansions:
TEST_NAME: encryption
tags: [encryption_tag]
@ -203,7 +197,7 @@ buildvariants:
display_name: Encryption macOS
run_on:
- macos-14
batchtime: 10080
batchtime: 1440
expansions:
TEST_NAME: encryption
tags: [encryption_tag]
@ -212,8 +206,8 @@ buildvariants:
- name: .test-non-standard !.pypy
display_name: Encryption Win64
run_on:
- windows-64-vsMulti-small
batchtime: 10080
- windows-2022-latest-small
batchtime: 1440
expansions:
TEST_NAME: encryption
tags: [encryption_tag]
@ -223,7 +217,7 @@ buildvariants:
display_name: Encryption crypt_shared RHEL8
run_on:
- rhel87-small
batchtime: 10080
batchtime: 1440
expansions:
TEST_NAME: encryption
TEST_CRYPT_SHARED: "true"
@ -234,7 +228,7 @@ buildvariants:
display_name: Encryption crypt_shared macOS
run_on:
- macos-14
batchtime: 10080
batchtime: 1440
expansions:
TEST_NAME: encryption
TEST_CRYPT_SHARED: "true"
@ -244,8 +238,8 @@ buildvariants:
- name: .test-non-standard !.pypy
display_name: Encryption crypt_shared Win64
run_on:
- windows-64-vsMulti-small
batchtime: 10080
- windows-2022-latest-small
batchtime: 1440
expansions:
TEST_NAME: encryption
TEST_CRYPT_SHARED: "true"
@ -256,7 +250,7 @@ buildvariants:
display_name: Encryption PyOpenSSL RHEL8
run_on:
- rhel87-small
batchtime: 10080
batchtime: 1440
expansions:
TEST_NAME: encryption
SUB_TEST_NAME: pyopenssl
@ -265,7 +259,7 @@ buildvariants:
# Enterprise auth tests
- name: auth-enterprise-rhel8
tasks:
- name: .test-non-standard .auth
- name: .test-standard-auth .auth !.free-threaded
display_name: Auth Enterprise RHEL8
run_on:
- rhel87-small
@ -274,7 +268,7 @@ buildvariants:
AUTH: auth
- name: auth-enterprise-macos
tasks:
- name: .test-non-standard !.pypy .auth
- name: .test-standard-auth !.pypy .auth !.free-threaded
display_name: Auth Enterprise macOS
run_on:
- macos-14
@ -283,64 +277,18 @@ buildvariants:
AUTH: auth
- name: auth-enterprise-win64
tasks:
- name: .test-non-standard !.pypy .auth
- name: .test-standard-auth !.pypy .auth !.free-threaded
display_name: Auth Enterprise Win64
run_on:
- windows-64-vsMulti-small
- windows-2022-latest-small
expansions:
TEST_NAME: enterprise_auth
AUTH: auth
# Free threaded tests
- name: free-threaded-rhel8-python3.14t
tasks:
- name: .free-threading
display_name: Free-threaded RHEL8 Python3.14t
run_on:
- rhel87-small
expansions:
PYTHON_BINARY: /opt/python/3.14t/bin/python3
tags: [pr]
- name: free-threaded-macos-python3.14t
tasks:
- name: .free-threading
display_name: Free-threaded macOS Python3.14t
run_on:
- macos-14
expansions:
PYTHON_BINARY: /Library/Frameworks/PythonT.Framework/Versions/3.14/bin/python3t
tags: []
- name: free-threaded-macos-arm64-python3.14t
tasks:
- name: .free-threading
display_name: Free-threaded macOS Arm64 Python3.14t
run_on:
- macos-14-arm64
expansions:
PYTHON_BINARY: /Library/Frameworks/PythonT.Framework/Versions/3.14/bin/python3t
tags: []
- name: free-threaded-win64-python3.14t
tasks:
- name: .free-threading
display_name: Free-threaded Win64 Python3.14t
run_on:
- windows-64-vsMulti-small
expansions:
PYTHON_BINARY: C:/python/Python314/python3.14t.exe
tags: []
# Green framework tests
- name: green-eventlet-rhel8
tasks:
- name: .test-standard .standalone-noauth-nossl .python-3.9 .sync
display_name: Green Eventlet RHEL8
run_on:
- rhel87-small
expansions:
GREEN_FRAMEWORK: eventlet
- name: green-gevent-rhel8
tasks:
- name: .test-standard .standalone-noauth-nossl .sync
- name: .test-standard .sync !.free-threaded
display_name: Green Gevent RHEL8
run_on:
- rhel87-small
@ -359,10 +307,10 @@ buildvariants:
- name: kms
tasks:
- name: test-gcpkms
batchtime: 10080
batchtime: 1440
- name: test-gcpkms-fail
- name: test-azurekms
batchtime: 10080
batchtime: 1440
- name: test-azurekms-fail
display_name: KMS
run_on:
@ -379,10 +327,18 @@ buildvariants:
display_name: Load Balancer
run_on:
- rhel87-small
batchtime: 10080
batchtime: 1440
expansions:
TEST_NAME: load_balancer
# Min support tests
- name: min-support-rhel8
tasks:
- name: .test-min-support
display_name: Min Support RHEL8
run_on:
- rhel87-small
# Mockupdb tests
- name: mockupdb-rhel8
tasks:
@ -411,6 +367,8 @@ buildvariants:
display_name: No C Ext RHEL8
run_on:
- rhel87-small
expansions:
NO_EXT: "1"
# No server tests
- name: no-server-rhel8
@ -435,7 +393,7 @@ buildvariants:
- name: .ocsp-rsa !.ocsp-staple .4.4
display_name: OCSP Win64
run_on:
- windows-64-vsMulti-small
- windows-2022-latest-small
batchtime: 10080
- name: ocsp-macos
tasks:
@ -453,14 +411,16 @@ buildvariants:
display_name: Auth OIDC Ubuntu-22
run_on:
- ubuntu2204-small
batchtime: 10080
batchtime: 1440
- name: auth-oidc-local-ubuntu-22
tasks:
- name: "!.auth_oidc_remote .auth_oidc"
display_name: Auth OIDC Local Ubuntu-22
run_on:
- ubuntu2204-small
batchtime: 10080
batchtime: 1440
expansions:
COVERAGE: "1"
tags: [pr]
- name: auth-oidc-macos
tasks:
@ -468,14 +428,14 @@ buildvariants:
display_name: Auth OIDC macOS
run_on:
- macos-14
batchtime: 10080
batchtime: 1440
- name: auth-oidc-win64
tasks:
- name: "!.auth_oidc_remote .auth_oidc"
display_name: Auth OIDC Win64
run_on:
- windows-64-vsMulti-small
batchtime: 10080
- windows-2022-latest-small
batchtime: 1440
# Perf tests
- name: performance-benchmarks
@ -484,7 +444,7 @@ buildvariants:
display_name: Performance Benchmarks
run_on:
- rhel90-dbx-perf-large
batchtime: 10080
batchtime: 1440
# Pyopenssl tests
- name: pyopenssl-rhel8
@ -494,7 +454,7 @@ buildvariants:
display_name: PyOpenSSL RHEL8
run_on:
- rhel87-small
batchtime: 10080
batchtime: 1440
expansions:
SUB_TEST_NAME: pyopenssl
- name: pyopenssl-macos
@ -504,7 +464,7 @@ buildvariants:
display_name: PyOpenSSL macOS
run_on:
- rhel87-small
batchtime: 10080
batchtime: 1440
expansions:
SUB_TEST_NAME: pyopenssl
- name: pyopenssl-win64
@ -513,20 +473,18 @@ buildvariants:
- name: .test-standard !.pypy .async .replica_set-noauth-ssl
display_name: PyOpenSSL Win64
run_on:
- rhel87-small
batchtime: 10080
- windows-2022-latest-small
batchtime: 1440
expansions:
SUB_TEST_NAME: pyopenssl
# Search index tests
- name: search-index-helpers-rhel8-python3.9
- name: search-index-helpers-rhel8
tasks:
- name: .search_index
display_name: Search Index Helpers RHEL8 Python3.9
display_name: Search Index Helpers RHEL8
run_on:
- rhel87-small
expansions:
PYTHON_BINARY: /opt/python/3.9/bin/python3
# Server version tests
- name: mongodb-v4.2
@ -659,7 +617,7 @@ buildvariants:
- name: .test-standard !.pypy
display_name: "* Test Win64"
run_on:
- windows-64-vsMulti-small
- windows-2022-latest-small
tags: [standard-non-linux]
- name: test-win32
tasks:
@ -680,3 +638,42 @@ buildvariants:
- rhel87-small
expansions:
STORAGE_ENGINE: inmemory
# Test numpy tests
- name: test-numpy-rhel8
tasks:
- name: .test-numpy
display_name: Test Numpy RHEL8
run_on:
- rhel87-small
tags: [binary, vector, pr]
- name: test-numpy-macos
tasks:
- name: .test-numpy
display_name: Test Numpy macOS
run_on:
- macos-14
tags: [binary, vector]
- name: test-numpy-macos-arm64
tasks:
- name: .test-numpy
display_name: Test Numpy macOS Arm64
run_on:
- macos-14-arm64
tags: [binary, vector]
- name: test-numpy-win64
tasks:
- name: .test-numpy
display_name: Test Numpy Win64
run_on:
- windows-2022-latest-small
tags: [binary, vector]
- name: test-numpy-win32
tasks:
- name: .test-numpy
display_name: Test Numpy Win32
run_on:
- windows-64-vsMulti-small
expansions:
IS_WIN32: "1"
tags: [binary, vector]

View File

@ -3,14 +3,10 @@ PYMONGO=$(dirname "$(cd "$(dirname "$0")" || exit; pwd)")
rm $PYMONGO/test/transactions/legacy/errors-client.json # PYTHON-1894
rm $PYMONGO/test/connection_monitoring/wait-queue-fairness.json # PYTHON-1873
rm $PYMONGO/test/client-side-encryption/spec/unified/fle2v2-BypassQueryAnalysis.json # PYTHON-5143
rm $PYMONGO/test/client-side-encryption/spec/unified/fle2v2-EncryptedFields-vs-EncryptedFieldsMap.json # PYTHON-5143
rm $PYMONGO/test/client-side-encryption/spec/unified/localSchema.json # PYTHON-5143
rm $PYMONGO/test/client-side-encryption/spec/unified/maxWireVersion.json # PYTHON-5143
rm $PYMONGO/test/unified-test-format/valid-pass/poc-queryable-encryption.json # PYTHON-5143
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-application-error.json # PYTHON-4918
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-checkout-error.json # PYTHON-4918
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-min-pool-size-error.json # PYTHON-4918
rm $PYMONGO/test/client-side-encryption/spec/unified/client-bulkWrite-qe.json # PYTHON-4929
# Python doesn't implement DRIVERS-3064
rm $PYMONGO/test/collection_management/listCollections-rawdata.json
@ -45,7 +41,7 @@ rm $PYMONGO/test/index_management/index-rawdata.json
rm $PYMONGO/test/collection_management/modifyCollection-*.json
# PYTHON-5248 - Remove support for MongoDB 4.0
rm $PYMONGO/test/**/pre-42-*.json
find /$PYMONGO/test -type f -name 'pre-42-*.json' -delete
# PYTHON-3359 - Remove Database and Collection level timeout override
rm $PYMONGO/test/csot/override-collection-timeoutMS.json
@ -54,4 +50,7 @@ rm $PYMONGO/test/csot/override-database-timeoutMS.json
# PYTHON-2943 - Socks5 Proxy Support
rm $PYMONGO/test/uri_options/proxy-options.json
# PYTHON-5517 - Avoid clearing the connection pool when the server connection rate limiter triggers
rm $PYMONGO/test/discovery_and_monitoring/unified/backpressure-*.json
echo "Done removing unimplemented tests"

View File

@ -76,9 +76,6 @@ do
auth)
cpjson auth/tests/ auth
;;
atlas-data-lake-testing|data_lake)
cpjson atlas-data-lake-testing/tests/ data_lake
;;
bson-binary-vector|bson_binary_vector)
cpjson bson-binary-vector/tests/ bson_binary_vector
;;
@ -97,6 +94,9 @@ do
change-streams|change_streams)
cpjson change-streams/tests/ change_streams/
;;
client-backpressure|client_backpressure)
cpjson client-backpressure/tests client-backpressure
;;
client-side-encryption|csfle|fle)
cpjson client-side-encryption/tests/ client-side-encryption/spec
cpjson client-side-encryption/corpus/ client-side-encryption/corpus

View File

@ -19,15 +19,14 @@ fi
# Now we can safely enable xtrace
set -o xtrace
# Install python with pip.
PYTHON_VER="python3.9"
# Install a c compiler.
apt-get -qq update < /dev/null > /dev/null
apt-get -qq install $PYTHON_VER $PYTHON_VER-venv build-essential $PYTHON_VER-dev -y < /dev/null > /dev/null
apt-get -q install -y build-essential
export PYTHON_BINARY=$PYTHON_VER
export SET_XTRACE_ON=1
cd src
rm -rf .venv
rm -f .evergreen/scripts/test-env.sh || true
rm -f .evergreen/scripts/env.sh || true
bash ./.evergreen/just.sh setup-tests auth_aws ecs-remote
bash .evergreen/just.sh run-tests

View File

@ -8,7 +8,7 @@ if [ ${OIDC_ENV} == "k8s" ]; then
SUB_TEST_NAME=$K8S_VARIANT-remote
else
SUB_TEST_NAME=$OIDC_ENV-remote
apt-get install -y python3-dev build-essential
sudo apt-get install -y python3-dev build-essential
fi
bash ./.evergreen/just.sh setup-tests auth_oidc $SUB_TEST_NAME

View File

@ -6,7 +6,8 @@ SCRIPT_DIR=$(dirname ${BASH_SOURCE:-$0})
SCRIPT_DIR="$( cd -- "$SCRIPT_DIR" > /dev/null 2>&1 && pwd )"
ROOT_DIR="$(dirname $SCRIPT_DIR)"
pushd $ROOT_DIR
PREV_DIR=$(pwd)
cd $ROOT_DIR
# Try to source the env file.
if [ -f $SCRIPT_DIR/scripts/env.sh ]; then
@ -25,11 +26,20 @@ else
exit 1
fi
# List the packages.
uv sync ${UV_ARGS} --reinstall --quiet
uv pip list
cleanup_tests() {
# Avoid leaving the lock file in a changed state when we change the resolution type.
if [ -n "${TEST_MIN_DEPS:-}" ]; then
git checkout uv.lock || true
fi
cd $PREV_DIR
}
trap "cleanup_tests" SIGINT ERR
# Start the test runner.
uv run ${UV_ARGS} .evergreen/scripts/run_tests.py "$@"
echo "Running tests with UV_PYTHON=${UV_PYTHON:-}..."
echo "UV_ARGS=${UV_ARGS}"
uv run ${UV_ARGS} --reinstall-package pymongo .evergreen/scripts/run_tests.py "$@"
echo "Running tests with UV_PYTHON=${UV_PYTHON:-}... done."
popd
cleanup_tests

View File

@ -11,11 +11,10 @@ pushd $HERE/../.. >/dev/null
BASE_SHA="$1"
HEAD_SHA="$2"
. .evergreen/utils.sh
if [ -z "${PYTHON_BINARY:-}" ]; then
PYTHON_BINARY=$(find_python3)
fi
# Set up the virtual env.
. $HERE/setup-dev-env.sh
uv venv --seed
source .venv/bin/activate
# Use the previous commit if this was not a PR run.
if [ "$BASE_SHA" == "$HEAD_SHA" ]; then
@ -24,7 +23,6 @@ fi
function get_import_time() {
local log_file
createvirtualenv "$PYTHON_BINARY" import-venv
python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]"
# Import once to cache modules
python -c "import pymongo"

View File

@ -7,6 +7,7 @@ from itertools import product
from generate_config_utils import (
ALL_PYTHONS,
ALL_VERSIONS,
BATCHTIME_DAY,
BATCHTIME_WEEK,
C_EXTS,
CPYTHONS,
@ -108,25 +109,10 @@ def create_standard_nonlinux_variants() -> list[BuildVariant]:
return variants
def create_free_threaded_variants() -> list[BuildVariant]:
variants = []
for host_name in ("rhel8", "macos", "macos-arm64", "win64"):
python = "3.14t"
tasks = [".free-threading"]
tags = []
if host_name == "rhel8":
tags.append("pr")
host = HOSTS[host_name]
display_name = get_variant_name("Free-threaded", host, python=python)
variant = create_variant(tasks, display_name, tags=tags, python=python, host=host)
variants.append(variant)
return variants
def create_encryption_variants() -> list[BuildVariant]:
variants = []
tags = ["encryption_tag"]
batchtime = BATCHTIME_WEEK
batchtime = BATCHTIME_DAY
def get_encryption_expansions(encryption):
expansions = dict(TEST_NAME="encryption")
@ -183,7 +169,7 @@ def create_load_balancer_variants():
tasks,
"Load Balancer",
host=DEFAULT_HOST,
batchtime=BATCHTIME_WEEK,
batchtime=BATCHTIME_DAY,
expansions=expansions,
)
]
@ -208,6 +194,22 @@ def create_compression_variants():
expansions=expansions,
)
)
# Add explicit tests with compression.zstd support on linux.
host = HOSTS["ubuntu22"]
expansions = dict(COMPRESSOR="ztsd")
tasks = [
".test-standard !.server-4.2 !.server-4.4 !.server-5.0 .python-3.14",
".test-standard !.server-4.2 !.server-4.4 !.server-5.0 .python-3.14t",
]
display_name = get_variant_name(f"Compression {compressor}", host)
variants.append(
create_variant(
tasks,
display_name,
host=host,
expansions=expansions,
)
)
return variants
@ -216,9 +218,13 @@ def create_enterprise_auth_variants():
for host in ["rhel8", "macos", "win64"]:
expansions = dict(TEST_NAME="enterprise_auth", AUTH="auth")
display_name = get_variant_name("Auth Enterprise", host)
tasks = [".test-non-standard .auth"]
if host != "rhel8":
tasks = [".test-non-standard !.pypy .auth"]
tasks = [".test-standard-auth .auth !.free-threaded"]
# https://jira.mongodb.org/browse/PYTHON-5586
if host == "macos":
tasks = [".test-standard-auth !.pypy .auth !.free-threaded"]
if host == "win64":
# https://jira.mongodb.org/browse/PYTHON-5704
tasks = [".test-standard-auth !.pypy .auth !.free-threaded"]
variant = create_variant(tasks, display_name, host=host, expansions=expansions)
variants.append(variant)
return variants
@ -226,7 +232,7 @@ def create_enterprise_auth_variants():
def create_pyopenssl_variants():
base_name = "PyOpenSSL"
batchtime = BATCHTIME_WEEK
batchtime = BATCHTIME_DAY
expansions = dict(SUB_TEST_NAME="pyopenssl")
variants = []
@ -300,12 +306,8 @@ def create_stable_api_variants():
def create_green_framework_variants():
variants = []
host = DEFAULT_HOST
for framework in ["eventlet", "gevent"]:
tasks = [".test-standard .standalone-noauth-nossl .sync"]
if framework == "eventlet":
# Eventlet has issues with dnspython > 2.0 and newer versions of CPython
# https://jira.mongodb.org/browse/PYTHON-5284
tasks = [".test-standard .standalone-noauth-nossl .python-3.9 .sync"]
for framework in ["gevent"]:
tasks = [".test-standard .sync !.free-threaded"]
expansions = dict(GREEN_FRAMEWORK=framework)
display_name = get_variant_name(f"Green {framework.capitalize()}", host)
variant = create_variant(tasks, display_name, host=host, expansions=expansions)
@ -319,15 +321,7 @@ def create_no_c_ext_variants():
expansions = dict()
handle_c_ext(C_EXTS[0], expansions)
display_name = get_variant_name("No C Ext", host)
return [create_variant(tasks, display_name, host=host)]
def create_atlas_data_lake_variants():
host = HOSTS["ubuntu22"]
tasks = [".test-no-orchestration"]
expansions = dict(TEST_NAME="data_lake")
display_name = get_variant_name("Atlas Data Lake", host)
return [create_variant(tasks, display_name, tags=["pr"], host=host, expansions=expansions)]
return [create_variant(tasks, display_name, host=host, expansions=expansions)]
def create_mod_wsgi_variants():
@ -341,10 +335,44 @@ def create_mod_wsgi_variants():
def create_disable_test_commands_variants():
host = DEFAULT_HOST
expansions = dict(AUTH="auth", SSL="ssl", DISABLE_TEST_COMMANDS="1")
python = CPYTHONS[0]
display_name = get_variant_name("Disable test commands", host, python=python)
display_name = get_variant_name("Disable test commands", host)
tasks = [".test-standard .server-latest"]
return [create_variant(tasks, display_name, host=host, python=python, expansions=expansions)]
return [create_variant(tasks, display_name, host=host, expansions=expansions)]
def create_test_numpy_tasks():
tasks = []
for python in MIN_MAX_PYTHON:
tags = ["binary", "vector", f"python-{python}", "test-numpy"]
vars = dict(TOOLCHAIN_VERSION=python)
if python == MIN_MAX_PYTHON[-1]:
tags.append("pr")
vars["COVERAGE"] = "1"
task_name = get_task_name("test-numpy", python=python, **vars)
test_func = FunctionCall(func="test numpy", vars=vars)
tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func]))
return tasks
def create_test_numpy_variants() -> list[BuildVariant]:
variants = []
base_display_name = "Test Numpy"
# Test a subset on each of the other platforms.
for host_name in ("rhel8", "macos", "macos-arm64", "win64", "win32"):
tasks = [".test-numpy"]
host = HOSTS[host_name]
tags = ["binary", "vector"]
if host_name == "rhel8":
tags.append("pr")
expansions = dict()
if host_name == "win32":
expansions["IS_WIN32"] = "1"
display_name = get_variant_name(base_display_name, host)
variant = create_variant(tasks, display_name, host=host, tags=tags, expansions=expansions)
variants.append(variant)
return variants
def create_oidc_auth_variants():
@ -360,7 +388,7 @@ def create_oidc_auth_variants():
tasks,
get_variant_name("Auth OIDC", host),
host=host,
batchtime=BATCHTIME_WEEK,
batchtime=BATCHTIME_DAY,
)
)
# Add a specific local test to run on PRs.
@ -372,7 +400,8 @@ def create_oidc_auth_variants():
get_variant_name("Auth OIDC Local", host),
tags=["pr"],
host=host,
batchtime=BATCHTIME_WEEK,
batchtime=BATCHTIME_DAY,
expansions=dict(COVERAGE="1"),
)
)
return variants
@ -380,12 +409,10 @@ def create_oidc_auth_variants():
def create_search_index_variants():
host = DEFAULT_HOST
python = CPYTHONS[0]
return [
create_variant(
[".search_index"],
get_variant_name("Search Index Helpers", host, python=python),
python=python,
get_variant_name("Search Index Helpers", host),
host=host,
)
]
@ -437,9 +464,9 @@ def create_coverage_report_variants():
def create_kms_variants():
tasks = []
tasks.append(EvgTaskRef(name="test-gcpkms", batchtime=BATCHTIME_WEEK))
tasks.append(EvgTaskRef(name="test-gcpkms", batchtime=BATCHTIME_DAY))
tasks.append("test-gcpkms-fail")
tasks.append(EvgTaskRef(name="test-azurekms", batchtime=BATCHTIME_WEEK))
tasks.append(EvgTaskRef(name="test-azurekms", batchtime=BATCHTIME_DAY))
tasks.append("test-azurekms-fail")
return [create_variant(tasks, "KMS", host=HOSTS["debian11"])]
@ -454,23 +481,21 @@ def create_backport_pr_variants():
def create_perf_variants():
host = HOSTS["perf"]
return [
create_variant([".perf"], "Performance Benchmarks", host=host, batchtime=BATCHTIME_WEEK)
]
return [create_variant([".perf"], "Performance Benchmarks", host=host, batchtime=BATCHTIME_DAY)]
def create_aws_auth_variants():
variants = []
for host_name in ["ubuntu20", "win64", "macos"]:
for host_name in ["rhel8", "win64", "macos"]:
expansions = dict()
tasks = [".auth-aws"]
tags = []
if host_name == "macos":
tasks = [".auth-aws !.auth-aws-web-identity !.auth-aws-ecs !.auth-aws-ec2"]
tasks = [".auth-aws !.auth-aws-web-identity !.auth-aws-ec2"]
tags = ["pr"]
elif host_name == "win64":
tasks = [".auth-aws !.auth-aws-ecs"]
tasks = [".auth-aws"]
host = HOSTS[host_name]
variant = create_variant(
tasks,
@ -480,9 +505,25 @@ def create_aws_auth_variants():
expansions=expansions,
)
variants.append(variant)
# The ECS test must be run on Ubuntu 24 to match the Fargate Config.
variant = create_variant(
[".auth-aws-ecs"],
get_variant_name("Auth AWS ECS", host),
host=HOSTS["ubuntu24"],
tags=tags,
expansions=expansions,
)
variants.append(variant)
return variants
def create_min_support_variants():
host = HOSTS["rhel8"]
name = get_variant_name("Min Support", host=host)
return [create_variant([".test-min-support"], name, host=host)]
def create_no_server_variants():
host = HOSTS["rhel8"]
name = get_variant_name("No server", host=host)
@ -490,22 +531,9 @@ def create_no_server_variants():
def create_alternative_hosts_variants():
batchtime = BATCHTIME_WEEK
batchtime = BATCHTIME_DAY
variants = []
host = HOSTS["rhel7"]
version = "5.0"
variants.append(
create_variant(
[".test-no-toolchain"],
get_variant_name("OpenSSL 1.0.2", host, python=CPYTHONS[0], version=version),
host=host,
python=CPYTHONS[0],
batchtime=batchtime,
expansions=dict(VERSION=version, PYTHON_VERSION=CPYTHONS[0]),
)
)
version = "latest"
for host_name in OTHER_HOSTS:
expansions = dict(VERSION="latest")
@ -514,6 +542,8 @@ def create_alternative_hosts_variants():
tags = []
if "fips" in host_name.lower():
expansions["REQUIRE_FIPS"] = "1"
# Use explicit Python 3.11 binary on the host since the default python3 is 3.9.
expansions["UV_PYTHON"] = "/usr/bin/python3.11"
if "amazon" in host_name.lower():
tags.append("pr")
variants.append(
@ -541,22 +571,20 @@ def create_aws_lambda_variants():
def create_server_version_tasks():
tasks = []
task_inputs = []
task_combos = set()
# All combinations of topology, auth, ssl, and sync should be tested.
for (topology, auth, ssl, sync), python in zip_cycle(
list(product(TOPOLOGIES, ["auth", "noauth"], ["ssl", "nossl"], SYNCS)), ALL_PYTHONS
):
task_inputs.append((topology, auth, ssl, sync, python))
task_combos.add((topology, auth, ssl, sync, python))
# Every python should be tested with sharded cluster, auth, ssl, with sync and async.
for python, sync in product(ALL_PYTHONS, SYNCS):
task_input = ("sharded_cluster", "auth", "ssl", sync, python)
if task_input not in task_inputs:
task_inputs.append(task_input)
task_combos.add(("sharded_cluster", "auth", "ssl", sync, python))
# Assemble the tasks.
seen = set()
for topology, auth, ssl, sync, python in task_inputs:
for topology, auth, ssl, sync, python in sorted(task_combos):
combo = f"{topology}-{auth}-{ssl}"
tags = ["server-version", f"python-{python}", combo, sync]
if combo in [
@ -569,12 +597,21 @@ def create_server_version_tasks():
seen.add(combo)
tags.append("pr")
expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology)
if python not in PYPYS:
if python == ALL_PYTHONS[0]:
expansions["TEST_MIN_DEPS"] = "1"
if "t" in python:
tags.append("free-threaded")
if "pr" in tags:
expansions["COVERAGE"] = "1"
name = get_task_name("test-server-version", python=python, sync=sync, **expansions)
name = get_task_name(
"test-server-version",
python=python,
sync=sync,
**expansions,
)
server_func = FunctionCall(func="run server", vars=expansions)
test_vars = expansions.copy()
test_vars["PYTHON_VERSION"] = python
test_vars["TOOLCHAIN_VERSION"] = python
test_vars["TEST_NAME"] = f"default_{sync}"
test_func = FunctionCall(func="run tests", vars=test_vars)
tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func]))
@ -603,15 +640,15 @@ def create_no_toolchain_tasks():
def create_test_non_standard_tasks():
"""For variants that set a TEST_NAME."""
tasks = []
task_combos = []
task_combos = set()
# For each version and topology, rotate through the CPythons.
for (version, topology), python in zip_cycle(list(product(ALL_VERSIONS, TOPOLOGIES)), CPYTHONS):
pr = version == "latest"
task_combos.append((version, topology, python, pr))
# For each PyPy and topology, rotate through the the versions.
task_combos.add((version, topology, python, pr))
# For each PyPy and topology, rotate through the MongoDB versions.
for (python, topology), version in zip_cycle(list(product(PYPYS, TOPOLOGIES)), ALL_VERSIONS):
task_combos.append((version, topology, python, False))
for version, topology, python, pr in task_combos:
task_combos.add((version, topology, python, False))
for version, topology, python, pr in sorted(task_combos):
auth, ssl = get_standard_auth_ssl(topology)
tags = [
"test-non-standard",
@ -620,15 +657,65 @@ def create_test_non_standard_tasks():
f"{topology}-{auth}-{ssl}",
auth,
]
if "t" in python:
tags.append("free-threaded")
if python in PYPYS:
tags.append("pypy")
if pr:
tags.append("pr")
expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology, VERSION=version)
if python == ALL_PYTHONS[0]:
expansions["TEST_MIN_DEPS"] = "1"
elif pr:
expansions["COVERAGE"] = "1"
name = get_task_name("test-non-standard", python=python, **expansions)
server_func = FunctionCall(func="run server", vars=expansions)
test_vars = expansions.copy()
test_vars["PYTHON_VERSION"] = python
test_vars["TOOLCHAIN_VERSION"] = python
test_func = FunctionCall(func="run tests", vars=test_vars)
tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func]))
return tasks
def create_test_standard_auth_tasks():
"""We only use auth on sharded clusters"""
tasks = []
task_combos = set()
# Rotate through the CPython and MongoDB versions
for (version, topology), python in zip_cycle(
list(product(ALL_VERSIONS, ["sharded_cluster"])), CPYTHONS
):
pr = version == "latest"
task_combos.add((version, topology, python, pr))
# Rotate through each PyPy and MongoDB versions.
for (python, topology), version in zip_cycle(
list(product(PYPYS, ["sharded_cluster"])), ALL_VERSIONS
):
task_combos.add((version, topology, python, False))
for version, topology, python, pr in sorted(task_combos):
auth, ssl = get_standard_auth_ssl(topology)
tags = [
"test-standard-auth",
f"server-{version}",
f"python-{python}",
f"{topology}-{auth}-{ssl}",
auth,
]
if "t" in python:
tags.append("free-threaded")
if python in PYPYS:
tags.append("pypy")
if pr:
tags.append("pr")
expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology, VERSION=version)
if python == ALL_PYTHONS[0]:
expansions["TEST_MIN_DEPS"] = "1"
elif pr:
expansions["COVERAGE"] = "1"
name = get_task_name("test-standard-auth", python=python, **expansions)
server_func = FunctionCall(func="run server", vars=expansions)
test_vars = expansions.copy()
test_vars["TOOLCHAIN_VERSION"] = python
test_func = FunctionCall(func="run tests", vars=test_vars)
tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func]))
return tasks
@ -654,15 +741,21 @@ def create_standard_tasks():
f"{topology}-{auth}-{ssl}",
sync,
]
if "t" in python:
tags.append("free-threaded")
if python in PYPYS:
tags.append("pypy")
if pr:
tags.append("pr")
expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology, VERSION=version)
if python == ALL_PYTHONS[0]:
expansions["TEST_MIN_DEPS"] = "1"
elif pr:
expansions["COVERAGE"] = "1"
name = get_task_name("test-standard", python=python, sync=sync, **expansions)
server_func = FunctionCall(func="run server", vars=expansions)
test_vars = expansions.copy()
test_vars["PYTHON_VERSION"] = python
test_vars["TOOLCHAIN_VERSION"] = python
test_vars["TEST_NAME"] = f"default_{sync}"
test_func = FunctionCall(func="run tests", vars=test_vars)
tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func]))
@ -676,9 +769,11 @@ def create_no_orchestration_tasks():
"test-no-orchestration",
f"python-{python}",
]
name = get_task_name("test-no-orchestration", python=python)
assume_func = FunctionCall(func="assume ec2 role")
test_vars = dict(PYTHON_VERSION=python)
test_vars = dict(TOOLCHAIN_VERSION=python)
if python == ALL_PYTHONS[0]:
test_vars["TEST_MIN_DEPS"] = "1"
name = get_task_name("test-no-orchestration", **test_vars)
test_func = FunctionCall(func="run tests", vars=test_vars)
commands = [assume_func, test_func]
tasks.append(EvgTask(name=name, tags=tags, commands=commands))
@ -715,17 +810,23 @@ def create_aws_tasks():
"env-creds",
"session-creds",
"web-identity",
"ecs",
]
assume_func = FunctionCall(func="assume ec2 role")
for version, test_type, python in zip_cycle(get_versions_from("4.4"), aws_test_types, CPYTHONS):
base_name = f"test-auth-aws-{version}"
base_tags = ["auth-aws"]
server_vars = dict(AUTH_AWS="1", VERSION=version)
server_func = FunctionCall(func="run server", vars=server_vars)
assume_func = FunctionCall(func="assume ec2 role")
tags = [*base_tags, f"auth-aws-{test_type}"]
name = get_task_name(f"{base_name}-{test_type}", python=python)
test_vars = dict(TEST_NAME="auth_aws", SUB_TEST_NAME=test_type, PYTHON_VERSION=python)
if "t" in python:
tags.append("free-threaded")
test_vars = dict(TEST_NAME="auth_aws", SUB_TEST_NAME=test_type, TOOLCHAIN_VERSION=python)
if python == MIN_MAX_PYTHON[0]:
test_vars["TEST_MIN_DEPS"] = "1"
elif python == MIN_MAX_PYTHON[-1]:
tags.append("pr")
test_vars["COVERAGE"] = "1"
name = get_task_name(f"{base_name}-{test_type}", **test_vars)
test_func = FunctionCall(func="run tests", vars=test_vars)
funcs = [server_func, assume_func, test_func]
tasks.append(EvgTask(name=name, tags=tags, commands=funcs))
@ -737,12 +838,24 @@ def create_aws_tasks():
TEST_NAME="auth_aws",
SUB_TEST_NAME="web-identity",
AWS_ROLE_SESSION_NAME="test",
PYTHON_VERSION=python,
TOOLCHAIN_VERSION=python,
)
if "t" in python:
tags.append("free-threaded")
test_func = FunctionCall(func="run tests", vars=test_vars)
funcs = [server_func, assume_func, test_func]
tasks.append(EvgTask(name=name, tags=tags, commands=funcs))
# Add the ECS task. This will run on Ubuntu 24 to match the
# Fargate environment.
tags = ["auth-aws-ecs"]
test_vars = dict(TEST_NAME="auth_aws", SUB_TEST_NAME="ecs")
name = get_task_name("test-auth-aws-ecs", **test_vars)
test_func = FunctionCall(func="run tests", vars=test_vars)
server_func = FunctionCall(func="run server", vars=dict(VERSION="8.0"))
funcs = [assume_func, server_func, test_func]
tasks.append(EvgTask(name=name, tags=tags, commands=funcs))
return tasks
@ -750,11 +863,11 @@ def create_oidc_tasks():
tasks = []
for sub_test in ["default", "azure", "gcp", "eks", "aks", "gke"]:
vars = dict(TEST_NAME="auth_oidc", SUB_TEST_NAME=sub_test)
test_func = FunctionCall(func="run tests", vars=vars)
task_name = f"test-auth-oidc-{sub_test}"
tags = ["auth_oidc"]
if sub_test != "default":
tags.append("auth_oidc_remote")
test_func = FunctionCall(func="run tests", vars=vars)
task_name = get_task_name(f"test-auth-oidc-{sub_test}", **vars)
tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func]))
return tasks
@ -765,15 +878,19 @@ def create_mod_wsgi_tasks():
for (test, topology), python in zip_cycle(
product(["standalone", "embedded-mode"], ["standalone", "replica_set"]), CPYTHONS
):
if "t" in python:
continue
if test == "standalone":
task_name = "mod-wsgi-"
else:
task_name = "mod-wsgi-embedded-mode-"
task_name += topology.replace("_", "-")
task_name = get_task_name(task_name, python=python)
server_vars = dict(TOPOLOGY=topology, PYTHON_VERSION=python)
server_vars = dict(TOPOLOGY=topology, TOOLCHAIN_VERSION=python)
server_func = FunctionCall(func="run server", vars=server_vars)
vars = dict(TEST_NAME="mod_wsgi", SUB_TEST_NAME=test.split("-")[0], PYTHON_VERSION=python)
vars = dict(
TEST_NAME="mod_wsgi", SUB_TEST_NAME=test.split("-")[0], TOOLCHAIN_VERSION=python
)
test_func = FunctionCall(func="run tests", vars=vars)
tags = ["mod_wsgi", "pr"]
commands = [server_func, test_func]
@ -795,21 +912,40 @@ def _create_ocsp_tasks(algo, variant, server_type, base_task_name):
ORCHESTRATION_FILE=file_name,
OCSP_SERVER_TYPE=server_type,
TEST_NAME="ocsp",
PYTHON_VERSION=python,
TOOLCHAIN_VERSION=python,
VERSION=version,
)
test_func = FunctionCall(func="run tests", vars=vars)
if python == ALL_PYTHONS[0]:
vars["TEST_MIN_DEPS"] = "1"
tags = ["ocsp", f"ocsp-{algo}", version]
if "disableStapling" not in variant:
tags.append("ocsp-staple")
if algo == "valid-cert-server-staples" and version == "latest":
if base_task_name == "valid-cert-server-staples" and version == "latest":
tags.append("pr")
task_name = get_task_name(
f"test-ocsp-{algo}-{base_task_name}", python=python, version=version
)
if "TEST_MIN_DEPS" not in vars:
vars["COVERAGE"] = "1"
test_func = FunctionCall(func="run tests", vars=vars)
task_name = get_task_name(f"test-ocsp-{algo}-{base_task_name}", **vars)
tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func]))
return tasks
def create_min_support_tasks():
server_func = FunctionCall(func="run server")
from generate_config_utils import MIN_SUPPORT_VERSIONS
tasks = []
for python, topology in product(MIN_SUPPORT_VERSIONS, TOPOLOGIES):
auth, ssl = get_standard_auth_ssl(topology)
vars = dict(UV_PYTHON=python, AUTH=auth, SSL=ssl, TOPOLOGY=topology)
test_func = FunctionCall(func="run tests", vars=vars)
task_name = get_task_name(
"test-min-support", python=python, topology=topology, auth=auth, ssl=ssl
)
tags = ["test-min-support"]
commands = [server_func, test_func]
tasks.append(EvgTask(name=task_name, tags=tags, commands=commands))
return tasks
@ -826,7 +962,7 @@ def create_aws_lambda_tasks():
def create_search_index_tasks():
assume_func = FunctionCall(func="assume ec2 role")
server_func = FunctionCall(func="run server", vars=dict(TEST_NAME="search_index"))
vars = dict(TEST_NAME="search_index")
vars = dict(TEST_NAME="search_index", TOOLCHAIN_VERSION=CPYTHONS[0])
test_func = FunctionCall(func="run tests", vars=vars)
task_name = "test-search-index-helpers"
tags = ["search_index"]
@ -935,15 +1071,6 @@ def create_ocsp_tasks():
return tasks
def create_free_threading_tasks():
vars = dict(VERSION="8.0", TOPOLOGY="replica_set")
server_func = FunctionCall(func="run server", vars=vars)
test_func = FunctionCall(func="run tests")
task_name = "test-free-threading"
tags = ["free-threading"]
return [EvgTask(name=task_name, tags=tags, commands=[server_func, test_func])]
##############
# Functions
##############
@ -964,6 +1091,26 @@ def create_upload_coverage_func():
return "upload coverage", [get_assume_role(), cmd]
def create_upload_coverage_codecov_func():
# Upload the coverage xml report to codecov.
include_expansions = [
"CODECOV_TOKEN",
"build_variant",
"task_name",
"github_commit",
"github_pr_number",
"github_pr_head_branch",
"github_author",
"requester",
"branch_name",
]
args = [
".evergreen/scripts/upload-codecov.sh",
]
upload_cmd = get_subprocess_exec(include_expansions_in_env=include_expansions, args=args)
return "upload codecov", [upload_cmd]
def create_download_and_merge_coverage_func():
include_expansions = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
args = [
@ -1046,8 +1193,8 @@ def create_run_server_func():
"AUTH",
"SSL",
"ORCHESTRATION_FILE",
"PYTHON_BINARY",
"PYTHON_VERSION",
"UV_PYTHON",
"TOOLCHAIN_VERSION",
"STORAGE_ENGINE",
"REQUIRE_API_VERSION",
"DRIVERS_TOOLS",
@ -1071,10 +1218,10 @@ def create_run_tests_func():
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"COVERAGE",
"PYTHON_BINARY",
"UV_PYTHON",
"LIBMONGOCRYPT_URL",
"MONGODB_URI",
"PYTHON_VERSION",
"TOOLCHAIN_VERSION",
"DISABLE_TEST_COMMANDS",
"GREEN_FRAMEWORK",
"NO_EXT",
@ -1088,6 +1235,7 @@ def create_run_tests_func():
"VERSION",
"IS_WIN32",
"REQUIRE_FIPS",
"TEST_MIN_DEPS",
]
args = [".evergreen/just.sh", "setup-tests", "${TEST_NAME}", "${SUB_TEST_NAME}"]
setup_cmd = get_subprocess_exec(include_expansions_in_env=includes, args=args)
@ -1095,6 +1243,14 @@ def create_run_tests_func():
return "run tests", [setup_cmd, test_cmd]
def create_test_numpy_func():
includes = ["TOOLCHAIN_VERSION", "COVERAGE"]
test_cmd = get_subprocess_exec(
include_expansions_in_env=includes, args=[".evergreen/just.sh", "test-numpy"]
)
return "test numpy", [test_cmd]
def create_cleanup_func():
cmd = get_subprocess_exec(args=[".evergreen/scripts/cleanup.sh"])
return "cleanup", [cmd]

View File

@ -22,11 +22,13 @@ from shrub.v3.shrub_service import ShrubService
##############
ALL_VERSIONS = ["4.2", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"]
CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
PYPYS = ["pypy3.10"]
CPYTHONS = ["3.10", "3.11", "3.12", "3.13", "3.14t", "3.14"]
PYPYS = ["pypy3.11"]
MIN_SUPPORT_VERSIONS = ["3.9", "pypy3.9", "pypy3.10"]
ALL_PYTHONS = CPYTHONS + PYPYS
MIN_MAX_PYTHON = [CPYTHONS[0], CPYTHONS[-1]]
BATCHTIME_WEEK = 10080
BATCHTIME_DAY = 1440
AUTH_SSLS = [("auth", "ssl"), ("noauth", "ssl"), ("noauth", "nossl")]
TOPOLOGIES = ["standalone", "replica_set", "sharded_cluster"]
C_EXTS = ["without_ext", "with_ext"]
@ -41,6 +43,7 @@ DISPLAY_LOOKUP = dict(
sync={"sync": "Sync", "async": "Async"},
coverage={"1": "cov"},
no_ext={"1": "No C"},
test_min_deps={"1": "Min Deps"},
)
HOSTS = dict()
@ -56,12 +59,12 @@ class Host:
# Hosts with toolchains.
HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8", dict())
HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64", dict())
HOSTS["win-latest"] = Host("win-latest", "windows-2022-latest-small", "WinLatest", dict())
HOSTS["win32"] = Host("win32", "windows-64-vsMulti-small", "Win32", dict())
HOSTS["macos"] = Host("macos", "macos-14", "macOS", dict())
HOSTS["macos-arm64"] = Host("macos-arm64", "macos-14-arm64", "macOS Arm64", dict())
HOSTS["ubuntu20"] = Host("ubuntu20", "ubuntu2004-small", "Ubuntu-20", dict())
HOSTS["ubuntu22"] = Host("ubuntu22", "ubuntu2204-small", "Ubuntu-22", dict())
HOSTS["rhel7"] = Host("rhel7", "rhel79-small", "RHEL7", dict())
HOSTS["ubuntu24"] = Host("ubuntu24", "ubuntu2404-small", "Ubuntu-24", dict())
HOSTS["perf"] = Host("perf", "rhel90-dbx-perf-large", "", dict())
HOSTS["debian11"] = Host("debian11", "debian11-small", "Debian11", dict())
DEFAULT_HOST = HOSTS["rhel8"]
@ -131,43 +134,25 @@ def create_variant(
*,
version: str | None = None,
host: Host | str | None = None,
python: str | None = None,
expansions: dict | None = None,
**kwargs: Any,
) -> BuildVariant:
expansions = expansions and expansions.copy() or dict()
if version:
expansions["VERSION"] = version
if python:
expansions["PYTHON_BINARY"] = get_python_binary(python, host)
# 8.0+ Windows builds must run on win-latest
if (
"win64" in display_name.lower()
or "win32" in display_name.lower()
and version
and version >= "8.0"
):
kwargs["run_on"] = HOSTS["win-latest"].run_on
return create_variant_generic(
tasks, display_name, version=version, host=host, expansions=expansions, **kwargs
)
def get_python_binary(python: str, host: Host) -> str:
"""Get the appropriate python binary given a python version and host."""
name = host.name
if name in ["win64", "win32"]:
if name == "win32":
base = "C:/python/32"
else:
base = "C:/python"
python_dir = python.replace(".", "").replace("t", "")
return f"{base}/Python{python_dir}/python{python}.exe"
if name in ["rhel8", "ubuntu22", "ubuntu20", "rhel7"]:
return f"/opt/python/{python}/bin/python3"
if name in ["macos", "macos-arm64"]:
bin_name = "python3t" if "t" in python else "python3"
python_dir = python.replace("t", "")
framework_dir = "PythonT" if "t" in python else "Python"
return f"/Library/Frameworks/{framework_dir}.Framework/Versions/{python_dir}/bin/{bin_name}"
raise ValueError(f"no match found for python {python} on {name}")
def get_versions_from(min_version: str) -> list[str]:
"""Get all server versions starting from a minimum version."""
min_version_float = float(min_version)
@ -196,12 +181,12 @@ def get_common_name(base: str, sep: str, **kwargs) -> str:
display_name = f"{display_name}{sep}{version}"
for key, value in kwargs.items():
name = value
if key.lower() == "python":
if key.lower() in ["python", "toolchain_version"]:
if not value.startswith("pypy"):
name = f"Python{value}"
else:
name = f"PyPy{value.replace('pypy', '')}"
elif key.lower() in DISPLAY_LOOKUP:
elif key.lower() in DISPLAY_LOOKUP and value in DISPLAY_LOOKUP[key.lower()]:
name = DISPLAY_LOOKUP[key.lower()][value]
else:
continue
@ -273,7 +258,7 @@ def generate_yaml(tasks=None, variants=None):
out = ShrubService.generate_yaml(project)
# Dedent by two spaces to match what we use in config.yml
lines = [line[2:] for line in out.splitlines()]
print("\n".join(lines)) # noqa: T201
print("\n".join(lines))
##################

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Install the dependencies needed for an evergreen run.
# Install the necessary dependencies.
set -eu
HERE=$(dirname ${BASH_SOURCE:-$0})
@ -13,50 +13,6 @@ fi
# Set up the default bin directory.
if [ -z "${PYMONGO_BIN_DIR:-}" ]; then
PYMONGO_BIN_DIR="$HOME/.local/bin"
export PATH="$PYMONGO_BIN_DIR:$PATH"
fi
# Helper function to pip install a dependency using a temporary python env.
function _pip_install() {
_HERE=$(dirname ${BASH_SOURCE:-$0})
. $_HERE/../utils.sh
_VENV_PATH=$(mktemp -d)
if [ "Windows_NT" = "${OS:-}" ]; then
_VENV_PATH=$(cygpath -m $_VENV_PATH)
fi
echo "Installing $2 using pip..."
createvirtualenv "$(find_python3)" $_VENV_PATH
python -m pip install $1
_suffix=""
if [ "Windows_NT" = "${OS:-}" ]; then
_suffix=".exe"
fi
ln -s "$(which $2)" $PYMONGO_BIN_DIR/${2}${_suffix}
# uv also comes with a uvx binary.
if [ $2 == "uv" ]; then
ln -s "$(which uvx)" $PYMONGO_BIN_DIR/uvx${_suffix}
fi
echo "Installed to ${PYMONGO_BIN_DIR}"
echo "Installing $2 using pip... done."
}
# Ensure just is installed.
if ! command -v just &>/dev/null; then
# On most systems we can install directly.
_TARGET=""
if [ "Windows_NT" = "${OS:-}" ]; then
_TARGET="--target x86_64-pc-windows-msvc"
fi
_BIN_DIR=$PYMONGO_BIN_DIR
mkdir -p ${_BIN_DIR}
echo "Installing just..."
mkdir -p "$_BIN_DIR" 2>/dev/null || true
curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- $_TARGET --to "$_BIN_DIR" || {
# Remove just file if it exists (can be created if there was an install error).
rm -f ${_BIN_DIR}/just
_pip_install rust-just just
}
echo "Installing just... done."
fi
# Ensure uv is installed.
@ -64,14 +20,17 @@ if ! command -v uv &>/dev/null; then
_BIN_DIR=$PYMONGO_BIN_DIR
mkdir -p ${_BIN_DIR}
echo "Installing uv..."
# On most systems we can install directly.
curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="$_BIN_DIR" INSTALLER_NO_MODIFY_PATH=1 sh || {
_pip_install uv uv
}
curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="$_BIN_DIR" INSTALLER_NO_MODIFY_PATH=1 sh
if [ "Windows_NT" = "${OS:-}" ]; then
chmod +x "$(cygpath -u $_BIN_DIR)/uv.exe"
fi
export PATH="$PYMONGO_BIN_DIR:$PATH"
echo "Installing uv... done."
fi
# Ensure just is installed.
if ! command -v just &>/dev/null; then
uv tool install rust-just
fi
popd > /dev/null

View File

@ -98,6 +98,13 @@ def setup_kms(sub_test_name: str) -> None:
if sub_test_target == "azure":
os.environ["AZUREKMS_VMNAME_PREFIX"] = "PYTHON_DRIVER"
# Found using "az vm image list --output table"
os.environ[
"AZUREKMS_IMAGE"
] = "Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest"
else:
os.environ["GCPKMS_IMAGEFAMILY"] = "debian-12"
run_command("./setup.sh", cwd=kms_dir)
base_env = _load_kms_config(sub_test_target)

View File

@ -42,6 +42,11 @@ def setup_oidc(sub_test_name: str) -> dict[str, str] | None:
if sub_test_name == "azure":
env["AZUREOIDC_VMNAME_PREFIX"] = "PYTHON_DRIVER"
if "-remote" not in sub_test_name:
if sub_test_name == "azure":
# Found using "az vm image list --output table"
env["AZUREOIDC_IMAGE"] = "Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest"
else:
env["GCPKMS_IMAGEFAMILY"] = "debian-12"
run_command(f"bash {target_dir}/setup.sh", env=env)
if sub_test_name in K8S_NAMES:
run_command(f"bash {target_dir}/setup-pod.sh {sub_test_name}")

View File

@ -6,12 +6,13 @@ import pathlib
import subprocess
from argparse import Namespace
from subprocess import CalledProcessError
from typing import Optional
JIRA_FILTER = "https://jira.mongodb.org/issues/?jql=labels%20%3D%20automated-sync%20AND%20status%20!%3D%20Closed"
def resync_specs(directory: pathlib.Path, errored: dict[str, str]) -> None:
"""Actually sync the specs"""
print("Beginning to sync specs") # noqa: T201
print("Beginning to sync specs")
for spec in os.scandir(directory):
if not spec.is_dir():
continue
@ -27,17 +28,34 @@ def resync_specs(directory: pathlib.Path, errored: dict[str, str]) -> None:
)
except CalledProcessError as exc:
errored[spec.name] = exc.stderr
print("Done syncing specs") # noqa: T201
print("Done syncing specs")
def apply_patches():
print("Beginning to apply patches") # noqa: T201
subprocess.run(["bash", "./.evergreen/remove-unimplemented-tests.sh"], check=True) # noqa: S603, S607
def apply_patches(errored):
print("Beginning to apply patches")
subprocess.run(
["git apply -R --allow-empty --whitespace=fix ./.evergreen/spec-patch/*"], # noqa: S607
shell=True, # noqa: S602
["bash", "./.evergreen/remove-unimplemented-tests.sh"], # noqa: S603, S607
check=True,
)
try:
# Avoid shell=True by passing arguments as a list.
# Note: glob expansion doesn't work in shell=False, so we use a list of files.
patches = [str(p) for p in pathlib.Path("./.evergreen/spec-patch/").glob("*")]
if patches:
subprocess.run(
[ # noqa: S603, S607
"git",
"apply",
"-R",
"--allow-empty",
"--whitespace=fix",
*patches,
],
check=True,
stderr=subprocess.PIPE,
)
except CalledProcessError as exc:
errored["applying patches"] = exc.stderr
def check_new_spec_directories(directory: pathlib.Path) -> list[str]:
@ -56,7 +74,6 @@ def check_new_spec_directories(directory: pathlib.Path) -> list[str]:
"client_side_operations_timeout": "csot",
"mongodb_handshake": "handshake",
"load_balancers": "load_balancer",
"atlas_data_lake_testing": "atlas",
"connection_monitoring_and_pooling": "connection_monitoring",
"command_logging_and_monitoring": "command_logging",
"initial_dns_seedlist_discovery": "srv_seedlist",
@ -70,23 +87,30 @@ def check_new_spec_directories(directory: pathlib.Path) -> list[str]:
return list(spec_set - test_set)
def write_summary(errored: dict[str, str], new: list[str], filename: Optional[str]) -> None:
def write_summary(errored: dict[str, str], new: list[str], filename: str | None) -> None:
"""Generate the PR description"""
pr_body = ""
# Avoid shell=True and complex pipes by using Python to process git output
process = subprocess.run(
["git diff --name-only | awk -F'/' '{print $2}' | sort | uniq"], # noqa: S607
shell=True, # noqa: S602
["git", "diff", "--name-only"], # noqa: S603, S607
capture_output=True,
text=True,
check=True,
)
succeeded = process.stdout.strip().split()
changed_files = process.stdout.strip().splitlines()
succeeded_set = set()
for f in changed_files:
parts = f.split("/")
if len(parts) > 1:
succeeded_set.add(parts[1])
succeeded = sorted(succeeded_set)
if len(succeeded) > 0:
pr_body += "The following specs were changed:\n -"
pr_body += "\n -".join(succeeded)
pr_body += "\n"
if len(errored) > 0:
pr_body += "\n\nThe following spec syncs encountered errors:\n -"
pr_body += "\n\nThe following spec syncs encountered errors:"
for k, v in errored.items():
pr_body += f"\n -{k}\n```{v}\n```"
pr_body += "\n"
@ -95,8 +119,9 @@ def write_summary(errored: dict[str, str], new: list[str], filename: Optional[st
pr_body += "\n -".join(new)
pr_body += "\n"
if pr_body != "":
pr_body = f"Jira tickets: {JIRA_FILTER}\n\n" + pr_body
if filename is None:
print(f"\n{pr_body}") # noqa: T201
print(f"\n{pr_body}")
else:
with open(filename, "w") as f:
# replacements made for proper json
@ -107,7 +132,7 @@ def main(args: Namespace):
directory = pathlib.Path("./test")
errored: dict[str, str] = {}
resync_specs(directory, errored)
apply_patches()
apply_patches(errored)
new = check_new_spec_directories(directory)
write_summary(errored, new, args.filename)
@ -117,7 +142,9 @@ if __name__ == "__main__":
description="Python Script to resync all specs and generate summary for PR."
)
parser.add_argument(
"--filename", help="Name of file for the summary to be written into.", default=None
"--filename",
help="Name of file for the summary to be written into.",
default=None,
)
args = parser.parse_args()
main(args)

View File

@ -1,5 +1,6 @@
#!/usr/bin/env bash
# Run spec syncing script and create PR
set -eu
# SETUP
SRC_URL="https://github.com/mongodb/specifications.git"

View File

@ -12,7 +12,7 @@ def set_env(name: str, value: Any = "1") -> None:
def start_server():
opts, extra_opts = get_test_options(
"Run a MongoDB server. All given flags will be passed to run-orchestration.sh in DRIVERS_TOOLS.",
"Run a MongoDB server. All given flags will be passed to run-mongodb.sh in DRIVERS_TOOLS.",
require_sub_test_name=False,
allow_extra_opts=True,
)
@ -51,7 +51,7 @@ def start_server():
elif opts.quiet:
extra_opts.append("-q")
cmd = ["bash", f"{DRIVERS_TOOLS}/.evergreen/run-orchestration.sh", *extra_opts]
cmd = ["bash", f"{DRIVERS_TOOLS}/.evergreen/run-mongodb.sh", "start", *extra_opts]
run_command(cmd, cwd=DRIVERS_TOOLS)

View File

@ -4,12 +4,20 @@ import json
import logging
import os
import platform
import shlex
import shutil
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from shutil import which
try:
import importlib_metadata
except ImportError:
from importlib import metadata as importlib_metadata
import pytest
from utils import DRIVERS_TOOLS, LOGGER, ROOT, run_command
@ -23,6 +31,22 @@ TEST_NAME = os.environ.get("TEST_NAME")
SUB_TEST_NAME = os.environ.get("SUB_TEST_NAME")
def list_packages():
packages = set()
for distribution in importlib_metadata.distributions():
if distribution.name:
packages.add(distribution.name)
print("Package Version URL")
print("------------------- ----------- ----------------------------------------------------")
for name in sorted(packages):
distribution = importlib_metadata.distribution(name)
url = ""
if distribution.origin is not None:
url = distribution.origin.url
print(f"{name:20s}{distribution.version:12s}{url}")
print("------------------- ----------- ----------------------------------------------------\n")
def handle_perf(start_time: datetime):
end_time = datetime.now()
elapsed_secs = (end_time - start_time).total_seconds()
@ -46,13 +70,7 @@ def handle_perf(start_time: datetime):
def handle_green_framework() -> None:
if GREEN_FRAMEWORK == "eventlet":
import eventlet
# https://github.com/eventlet/eventlet/issues/401
eventlet.sleep()
eventlet.monkey_patch()
elif GREEN_FRAMEWORK == "gevent":
if GREEN_FRAMEWORK == "gevent":
from gevent import monkey
monkey.patch_all()
@ -90,10 +108,11 @@ def handle_aws_lambda() -> None:
env["TEST_LAMBDA_DIRECTORY"] = str(target_dir)
env.setdefault("AWS_REGION", "us-east-1")
dirs = ["pymongo", "gridfs", "bson"]
# Store the original .so files.
before_sos = []
# Remove the original .so files.
for dname in dirs:
before_sos.extend(f"{f.parent.name}/{f.name}" for f in (ROOT / dname).glob("*.so"))
so_paths = [f"{f.parent.name}/{f.name}" for f in (ROOT / dname).glob("*.so")]
for so_path in list(so_paths):
Path(so_path).unlink()
# Build the c extensions.
docker = which("docker") or which("podman")
if not docker:
@ -106,21 +125,23 @@ def handle_aws_lambda() -> None:
target = ROOT / "test/lambda/mongodb" / dname
shutil.rmtree(target, ignore_errors=True)
shutil.copytree(ROOT / dname, target)
# Remove the original so files from the lambda directory.
for so_path in before_sos:
(ROOT / "test/lambda/mongodb" / so_path).unlink()
# Remove the new so files from the ROOT directory.
for dname in dirs:
so_paths = [f"{f.parent.name}/{f.name}" for f in (ROOT / dname).glob("*.so")]
for so_path in list(so_paths):
if so_path not in before_sos:
Path(so_path).unlink()
Path(so_path).unlink()
script_name = "run-deployed-lambda-aws-tests.sh"
run_command(f"bash {DRIVERS_TOOLS}/.evergreen/aws_lambda/{script_name}", env=env)
def run() -> None:
# Add diagnostic for python version.
print("Running with python", sys.version)
# List the installed packages.
list_packages()
# Handle green framework first so they can patch modules.
if GREEN_FRAMEWORK:
handle_green_framework()
@ -183,6 +204,16 @@ def run() -> None:
if os.environ.get("DEBUG_LOG"):
TEST_ARGS.extend(f"-o log_cli_level={logging.DEBUG}".split())
if os.environ.get("COVERAGE"):
binary = sys.executable.replace(os.sep, "/")
cmd = f"{binary} -m coverage run -m pytest {' '.join(TEST_ARGS)} {' '.join(sys.argv[1:])}"
result = subprocess.run(shlex.split(cmd), check=False) # noqa: S603
cmd = f"{binary} -m coverage report"
subprocess.run(shlex.split(cmd), check=False) # noqa: S603
if result.returncode != 0:
print(result.stderr)
sys.exit(result.returncode)
# Run local tests.
ret = pytest.main(TEST_ARGS + sys.argv[1:])
if ret != 0:

View File

@ -1,17 +1,17 @@
#!/bin/bash
# Set up a development environment on an evergreen host.
# Set up development environment.
set -eu
HERE=$(dirname ${BASH_SOURCE:-$0})
HERE="$( cd -- "$HERE" > /dev/null 2>&1 && pwd )"
ROOT=$(dirname "$(dirname $HERE)")
pushd $ROOT > /dev/null
# Source the env files to pick up common variables.
if [ -f $HERE/env.sh ]; then
. $HERE/env.sh
fi
# PYTHON_BINARY or PYTHON_VERSION may be defined in test-env.sh.
# Get variables defined in test-env.sh.
if [ -f $HERE/test-env.sh ]; then
. $HERE/test-env.sh
fi
@ -19,41 +19,40 @@ fi
# Ensure dependencies are installed.
bash $HERE/install-dependencies.sh
# Get the appropriate UV_PYTHON.
. $ROOT/.evergreen/utils.sh
# Handle the value for UV_PYTHON.
. $HERE/setup-uv-python.sh
if [ -z "${PYTHON_BINARY:-}" ]; then
if [ -n "${PYTHON_VERSION:-}" ]; then
PYTHON_BINARY=$(get_python_binary $PYTHON_VERSION)
else
PYTHON_BINARY=$(find_python3)
# Only run the next part if not running on CI.
if [ -z "${CI:-}" ]; then
# Add the default install path to the path if needed.
if [ -z "${PYMONGO_BIN_DIR:-}" ]; then
export PATH="$PATH:$HOME/.local/bin"
fi
# Set up venv, making sure c extensions build unless disabled.
if [ -z "${NO_EXT:-}" ]; then
export PYMONGO_C_EXT_MUST_BUILD=1
fi
(
cd $ROOT && uv sync
)
# Set up build utilities on Windows spawn hosts.
if [ -f $HOME/.visualStudioEnv.sh ]; then
set +u
SSH_TTY=1 source $HOME/.visualStudioEnv.sh
set -u
fi
# Only set up pre-commit if we are in a git checkout.
if [ -f $HERE/.git ]; then
if ! command -v pre-commit &>/dev/null; then
uv tool install pre-commit
fi
fi
export UV_PYTHON=${PYTHON_BINARY}
echo "Using python $UV_PYTHON"
# Add the default install path to the path if needed.
if [ -z "${PYMONGO_BIN_DIR:-}" ]; then
export PATH="$PATH:$HOME/.local/bin"
if [ ! -f .git/hooks/pre-commit ]; then
uvx pre-commit install
fi
fi
fi
# Set up venv, making sure c extensions build unless disabled.
if [ -z "${NO_EXT:-}" ]; then
export PYMONGO_C_EXT_MUST_BUILD=1
fi
# Set up visual studio env on Windows spawn hosts.
if [ -f $HOME/.visualStudioEnv.sh ]; then
set +u
SSH_TTY=1 source $HOME/.visualStudioEnv.sh
set -u
fi
uv sync --frozen
echo "Setting up python environment... done."
# Ensure there is a pre-commit hook if there is a git checkout.
if [ -d .git ] && [ ! -f .git/hooks/pre-commit ]; then
uv run --frozen pre-commit install
fi
popd > /dev/null

View File

@ -8,9 +8,13 @@ echo "Setting up system..."
bash .evergreen/scripts/configure-env.sh
source .evergreen/scripts/env.sh
bash $DRIVERS_TOOLS/.evergreen/setup.sh
bash .evergreen/scripts/install-dependencies.sh
popd
# Run spawn host-specific tasks.
if [ -z "${CI:-}" ]; then
bash $HERE/setup-dev-env.sh
fi
# Enable core dumps if enabled on the machine
# Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml
if [ -f /proc/self/coredump_filter ]; then

View File

@ -12,6 +12,7 @@ set -eu
# TEST_CRYPT_SHARED If non-empty, install crypt_shared lib.
# MONGODB_API_VERSION The mongodb api version to use in tests.
# MONGODB_URI If non-empty, use as the MONGODB_URI in tests.
# USE_ACTIVE_VENV If non-empty, use the active virtual environment.
SCRIPT_DIR=$(dirname ${BASH_SOURCE:-$0})
@ -21,5 +22,5 @@ if [ -f $SCRIPT_DIR/env.sh ]; then
fi
echo "Setting up tests with args \"$*\"..."
uv run $SCRIPT_DIR/setup_tests.py "$@"
uv run ${USE_ACTIVE_VENV:+--active} "$SCRIPT_DIR/setup_tests.py" "$@"
echo "Setting up tests with args \"$*\"... done."

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Set up the UV_PYTHON variable.
set -eu
HERE=$(dirname ${BASH_SOURCE:-$0})
HERE="$( cd -- "$HERE" > /dev/null 2>&1 && pwd )"
# Use min supported version by default.
_python="3.10"
# Source the env files to pick up common variables.
if [ -f $HERE/env.sh ]; then
. $HERE/env.sh
fi
# Get variables defined in test-env.sh.
if [ -f $HERE/test-env.sh ]; then
. $HERE/test-env.sh
fi
if [ -z "${UV_PYTHON:-}" ]; then
set -x
# Translate a TOOLCHAIN_VERSION to UV_PYTHON.
if [ -n "${TOOLCHAIN_VERSION:-}" ]; then
_python=$TOOLCHAIN_VERSION
if [ "$(uname -s)" = "Darwin" ]; then
if [[ "$_python" == *"t"* ]]; then
binary_name="python3t"
framework_dir="PythonT"
else
binary_name="python3"
framework_dir="Python"
fi
_python=$(echo "$_python" | sed 's/t//g')
_python="/Library/Frameworks/$framework_dir.Framework/Versions/$_python/bin/$binary_name"
elif [ "Windows_NT" = "${OS:-}" ]; then
_python=$(echo $_python | cut -d. -f1,2 | sed 's/\.//g; s/t//g')
if [[ "$TOOLCHAIN_VERSION" == *"t"* ]]; then
_exe="python${TOOLCHAIN_VERSION}.exe"
else
_exe="python.exe"
fi
if [ -n "${IS_WIN32:-}" ]; then
_python="C:/python/32/Python${_python}/${_exe}"
else
_python="C:/python/Python${_python}/${_exe}"
fi
elif [ -d "/opt/python/$_python/bin" ]; then
_python="/opt/python/$_python/bin/python3"
fi
fi
export UV_PYTHON="$_python"
fi

View File

@ -1,12 +1,10 @@
from __future__ import annotations
import base64
import io
import os
import platform
import shutil
import stat
import tarfile
from pathlib import Path
from urllib import request
@ -31,8 +29,7 @@ PASS_THROUGH_ENV = [
"NO_EXT",
"MONGODB_API_VERSION",
"DEBUG_LOG",
"PYTHON_BINARY",
"PYTHON_VERSION",
"UV_PYTHON",
"REQUIRE_FIPS",
"IS_WIN32",
]
@ -53,7 +50,7 @@ EXTRAS_MAP = {
GROUP_MAP = dict(mockupdb="mockupdb", perf="perf")
# The python version used for perf tests.
PERF_PYTHON_VERSION = "3.9.13"
PERF_PYTHON_VERSION = "3.10.11"
def is_set(var: str) -> bool:
@ -90,6 +87,13 @@ def setup_libmongocrypt():
distro = get_distro()
if distro.name.startswith("Debian"):
target = f"debian{distro.version_id}"
elif distro.name.startswith("Ubuntu"):
if distro.version_id == "20.04":
target = "debian11"
elif distro.version_id == "22.04":
target = "debian12"
elif distro.version_id == "24.04":
target = "debian13"
elif distro.name.startswith("Red Hat"):
if distro.version_id.startswith("7"):
target = "rhel-70-64-bit"
@ -111,9 +115,10 @@ def setup_libmongocrypt():
LOGGER.info(f"Fetching {url}...")
with request.urlopen(request.Request(url), timeout=15.0) as response: # noqa: S310
if response.status == 200:
fileobj = io.BytesIO(response.read())
with tarfile.open("libmongocrypt.tar.gz", fileobj=fileobj) as fid:
fid.extractall(Path.cwd() / "libmongocrypt")
with Path("libmongocrypt.tar.gz").open("wb") as f:
f.write(response.read())
Path("libmongocrypt").mkdir()
run_command("tar -xzf libmongocrypt.tar.gz -C libmongocrypt")
LOGGER.info(f"Fetching {url}... done.")
run_command("ls -la libmongocrypt")
@ -148,6 +153,10 @@ def handle_test_env() -> None:
# Start compiling the args we'll pass to uv.
UV_ARGS = ["--extra test --no-group dev"]
# If USE_ACTIVE_VENV is set, add --active to UV_ARGS so run-tests.sh uses the active venv.
if is_set("USE_ACTIVE_VENV"):
UV_ARGS.append("--active")
test_title = test_name
if sub_test_name:
test_title += f" {sub_test_name}"
@ -160,7 +169,6 @@ def handle_test_env() -> None:
write_env("PIP_QUIET") # Quiet by default.
write_env("PIP_PREFER_BINARY") # Prefer binary dists by default.
write_env("UV_FROZEN") # Do not modify lock files.
# Set an environment variable for the test name and sub test name.
write_env(f"TEST_{test_name.upper()}")
@ -178,6 +186,9 @@ def handle_test_env() -> None:
if group := GROUP_MAP.get(test_name, ""):
UV_ARGS.append(f"--group {group}")
if opts.test_min_deps:
UV_ARGS.append("--resolution=lowest-direct")
if test_name == "auth_oidc":
from oidc_tester import setup_oidc
@ -214,18 +225,8 @@ def handle_test_env() -> None:
if key in os.environ:
write_env(key, os.environ[key])
if test_name == "data_lake":
# Stop any running mongo-orchestration which might be using the port.
run_command(f"bash {DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh")
run_command(f"bash {DRIVERS_TOOLS}/.evergreen/atlas_data_lake/setup.sh")
AUTH = "auth"
if AUTH != "noauth":
if test_name == "data_lake":
config = read_env(f"{DRIVERS_TOOLS}/.evergreen/atlas_data_lake/secrets-export.sh")
DB_USER = config["ADL_USERNAME"]
DB_PASSWORD = config["ADL_PASSWORD"]
elif test_name == "auth_oidc":
if test_name == "auth_oidc":
DB_USER = config["OIDC_ADMIN_USER"]
DB_PASSWORD = config["OIDC_ADMIN_PWD"]
elif test_name == "search_index":
@ -243,7 +244,7 @@ def handle_test_env() -> None:
if is_set("MONGODB_URI"):
write_env("PYMONGO_MUST_CONNECT", "true")
if is_set("DISABLE_TEST_COMMANDS") or opts.disable_test_commands:
if opts.disable_test_commands:
write_env("PYMONGO_DISABLE_TEST_COMMANDS", "1")
if test_name == "enterprise_auth":
@ -327,7 +328,8 @@ def handle_test_env() -> None:
version = os.environ.get("VERSION", "latest")
cmd = [
"bash",
f"{DRIVERS_TOOLS}/.evergreen/run-orchestration.sh",
f"{DRIVERS_TOOLS}/.evergreen/run-mongodb.sh",
"start",
"--ssl",
"--version",
version,
@ -355,8 +357,10 @@ def handle_test_env() -> None:
if not (ROOT / "libmongocrypt").exists():
setup_libmongocrypt()
# TODO: Test with 'pip install pymongocrypt'
UV_ARGS.append("--group pymongocrypt_source")
if not opts.test_min_deps:
UV_ARGS.append(
"--with pymongocrypt@git+https://github.com/mongodb/libmongocrypt@master#subdirectory=bindings/python"
)
# Use the nocrypto build to avoid dependency issues with older windows/python versions.
BASE = ROOT / "libmongocrypt/nocrypto"
@ -385,7 +389,7 @@ def handle_test_env() -> None:
if sub_test_name == "pyopenssl":
UV_ARGS.append("--extra ocsp")
if is_set("TEST_CRYPT_SHARED") or opts.crypt_shared:
if opts.crypt_shared:
config = read_env(f"{DRIVERS_TOOLS}/mo-expansion.sh")
CRYPT_SHARED_DIR = Path(config["CRYPT_SHARED_LIB_PATH"]).parent.as_posix()
LOGGER.info("Using crypt_shared_dir %s", CRYPT_SHARED_DIR)
@ -432,6 +436,9 @@ def handle_test_env() -> None:
# We do not want the default client_context to be initialized.
write_env("DISABLE_CONTEXT")
if test_name == "numpy":
UV_ARGS.append("--with numpy")
if test_name == "perf":
data_dir = ROOT / "specifications/source/benchmarking/data"
if not data_dir.exists():
@ -455,16 +462,18 @@ def handle_test_env() -> None:
# Add coverage if requested.
# Only cover CPython. PyPy reports suspiciously low coverage.
if (is_set("COVERAGE") or opts.cov) and platform.python_implementation() == "CPython":
if opts.cov and platform.python_implementation() == "CPython":
# Keep in sync with combine-coverage.sh.
# coverage >=5 is needed for relative_files=true.
UV_ARGS.append("--group coverage")
TEST_ARGS = f"{TEST_ARGS} --cov"
write_env("COVERAGE")
if is_set("GREEN_FRAMEWORK") or opts.green_framework:
if opts.green_framework:
framework = opts.green_framework or os.environ["GREEN_FRAMEWORK"]
UV_ARGS.append(f"--group {framework}")
if framework == "gevent" and opts.test_min_deps:
# PYTHON-5729. This can be removed when the min supported gevent is moved to 25.9.1.
UV_ARGS.append('--with "setuptools==81.0"')
else:
TEST_ARGS = f"-v --durations=5 {TEST_ARGS}"

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Stop a server that was started using run-orchestration.sh in DRIVERS_TOOLS.
# Stop a server that was started using run-mongodb.sh in DRIVERS_TOOLS.
set -eu
HERE=$(dirname ${BASH_SOURCE:-$0})
@ -11,4 +11,4 @@ if [ -f $HERE/env.sh ]; then
source $HERE/env.sh
fi
bash ${DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh
bash ${DRIVERS_TOOLS}/.evergreen/run-mongodb.sh stop

View File

@ -57,10 +57,6 @@ elif TEST_NAME == "mod_wsgi":
teardown_mod_wsgi()
# Tear down data_lake if applicable.
elif TEST_NAME == "data_lake":
run_command(f"{DRIVERS_TOOLS}/.evergreen/atlas_data_lake/teardown.sh")
# Tear down coverage if applicable.
if os.environ.get("COVERAGE"):
shutil.rmtree(".pytest_cache", ignore_errors=True)

View File

@ -0,0 +1,57 @@
#!/bin/bash
# shellcheck disable=SC2154
# Upload a coverate report to codecov.
set -eu
HERE=$(dirname ${BASH_SOURCE:-$0})
ROOT=$(dirname "$(dirname $HERE)")
pushd $ROOT > /dev/null
export FNAME=coverage.xml
REQUESTER=${requester:-}
if [ ! -f ".coverage" ]; then
echo "There are no coverage results, not running codecov"
exit 0
fi
if [[ "${REQUESTER}" == "github_pr" || "${REQUESTER}" == "commit" ]]; then
echo "Uploading codecov for $REQUESTER..."
else
echo "Error: requester must be 'github_pr' or 'commit', got '${REQUESTER}'" >&2
exit 1
fi
printf 'sha: %s\n' "$github_commit"
printf 'flag: %s-%s\n' "$build_variant" "$task_name"
printf 'file: %s\n' "$FNAME"
uv tool run --with "coverage[toml]" coverage xml
codecov_args=(
upload-process
--report-type coverage
--disable-search
--fail-on-error
--git-service github
--token "${CODECOV_TOKEN}"
--sha "${github_commit}"
--flag "${build_variant}-${task_name}"
--file "${FNAME}"
)
if [ -n "${github_pr_number:-}" ]; then
printf 'branch: %s:%s\n' "$github_author" "$github_pr_head_branch"
printf 'pr: %s\n' "$github_pr_number"
uv tool run --from codecov-cli codecovcli \
"${codecov_args[@]}" \
--pr "${github_pr_number}" \
--branch "${github_author}:${github_pr_head_branch}"
else
printf 'branch: %s\n' "$branch_name"
uv tool run --from codecov-cli codecovcli \
"${codecov_args[@]}" \
--branch "${branch_name}"
fi
echo "Uploading codecov for $REQUESTER... done."
popd > /dev/null

View File

@ -33,7 +33,6 @@ TEST_SUITE_MAP = {
"atlas_connect": "atlas_connect",
"auth_aws": "auth_aws",
"auth_oidc": "auth_oidc",
"data_lake": "data_lake",
"default": "",
"default_async": "default_async",
"default_sync": "default",
@ -45,6 +44,7 @@ TEST_SUITE_MAP = {
"mockupdb": "mockupdb",
"ocsp": "ocsp",
"perf": "perf",
"numpy": "",
}
# Tests that require a sub test suite.
@ -52,16 +52,18 @@ SUB_TEST_REQUIRED = ["auth_aws", "auth_oidc", "kms", "mod_wsgi", "perf"]
EXTRA_TESTS = ["mod_wsgi", "aws_lambda", "doctest"]
# Tests that do not use run-orchestration directly.
# Tests that do not use run-mongodb directly.
NO_RUN_ORCHESTRATION = [
"auth_oidc",
"atlas_connect",
"aws_lambda",
"data_lake",
"mockupdb",
"ocsp",
]
# Mapping of env variables to options
OPTION_TO_ENV_VAR = {"cov": "COVERAGE", "crypt_shared": "TEST_CRYPT_SHARED"}
def get_test_options(
description, require_sub_test_name=True, allow_extra_opts=False
@ -96,6 +98,9 @@ def get_test_options(
)
parser.add_argument("--auth", action="store_true", help="Whether to add authentication.")
parser.add_argument("--ssl", action="store_true", help="Whether to add TLS configuration.")
parser.add_argument(
"--test-min-deps", action="store_true", help="Test against minimum dependency versions"
)
# Add the test modifiers.
if require_sub_test_name:
@ -106,7 +111,7 @@ def get_test_options(
parser.add_argument(
"--green-framework",
nargs=1,
choices=["eventlet", "gevent"],
choices=["gevent"],
help="Optional green framework to test against.",
)
parser.add_argument(
@ -129,26 +134,53 @@ def get_test_options(
opts, extra_opts = parser.parse_args(), []
else:
opts, extra_opts = parser.parse_known_args()
if opts.verbose:
LOGGER.setLevel(logging.DEBUG)
elif opts.quiet:
LOGGER.setLevel(logging.WARNING)
# Convert list inputs to strings.
for name in vars(opts):
value = getattr(opts, name)
if isinstance(value, list):
setattr(opts, name, value[0])
# Handle validation and environment variable overrides.
test_name = opts.test_name
sub_test_name = opts.sub_test_name if require_sub_test_name else ""
if require_sub_test_name and test_name in SUB_TEST_REQUIRED and not sub_test_name:
raise ValueError(f"Test '{test_name}' requires a sub_test_name")
if "auth" in test_name or os.environ.get("AUTH") == "auth":
handle_env_overrides(parser, opts)
if "auth" in test_name:
opts.auth = True
# 'auth_aws ecs' shouldn't have extra auth set.
if test_name == "auth_aws" and sub_test_name == "ecs":
opts.auth = False
if os.environ.get("SSL") == "ssl":
opts.ssl = True
if opts.verbose:
LOGGER.setLevel(logging.DEBUG)
elif opts.quiet:
LOGGER.setLevel(logging.WARNING)
return opts, extra_opts
def handle_env_overrides(parser: argparse.ArgumentParser, opts: argparse.Namespace) -> None:
# Get the options, and then allow environment variable overrides.
for key in vars(opts):
if key in OPTION_TO_ENV_VAR:
env_var = OPTION_TO_ENV_VAR[key]
else:
env_var = key.upper()
if env_var in os.environ:
if parser.get_default(key) != getattr(opts, key):
LOGGER.info("Overriding env var '%s' with cli option", env_var)
elif env_var == "AUTH":
opts.auth = os.environ.get("AUTH") == "auth"
elif env_var == "SSL":
ssl_opt = os.environ.get("SSL", "")
opts.ssl = ssl_opt and ssl_opt.lower() != "nossl"
elif isinstance(getattr(opts, key), bool):
if os.environ[env_var]:
setattr(opts, key, True)
else:
setattr(opts, key, os.environ[env_var])
def read_env(path: Path | str) -> dict[str, str]:
config = dict()
with Path(path).open() as fid:

View File

@ -15,5 +15,4 @@ echo "Copying files to $target..."
rsync -az -e ssh --exclude '.git' --filter=':- .gitignore' -r . $target:$remote_dir
echo "Copying files to $target... done"
ssh $target $remote_dir/.evergreen/scripts/setup-system.sh
ssh $target "cd $remote_dir && PYTHON_BINARY=${PYTHON_BINARY:-} .evergreen/scripts/setup-dev-env.sh"
ssh $target "$remote_dir/.evergreen/scripts/setup-system.sh"

View File

@ -1,14 +1,14 @@
diff --git a/test/discovery_and_monitoring/unified/serverMonitoringMode.json b/test/discovery_and_monitoring/unified/serverMonitoringMode.json
index 4b492f7d8..e44fad1bc 100644
index e44fad1b..4b492f7d 100644
--- a/test/discovery_and_monitoring/unified/serverMonitoringMode.json
+++ b/test/discovery_and_monitoring/unified/serverMonitoringMode.json
@@ -5,8 +5,7 @@
@@ -5,7 +5,8 @@
{
"topologies": [
"single",
+ "sharded"
- "sharded",
- "sharded-replicaset"
- "sharded"
+ "sharded",
+ "sharded-replicaset"
],
"serverless": "forbid"
}

View File

@ -0,0 +1,440 @@
diff --git a/test/unified-test-format/invalid/entity-client-observeTracingMessages-additionalProperties.json b/test/unified-test-format/invalid/entity-client-observeTracingMessages-additionalProperties.json
new file mode 100644
index 00000000..aa8046d2
--- /dev/null
+++ b/test/unified-test-format/invalid/entity-client-observeTracingMessages-additionalProperties.json
@@ -0,0 +1,20 @@
+{
+ "description": "entity-client-observeTracingMessages-additionalProperties",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0",
+ "observeTracingMessages": {
+ "foo": "bar"
+ }
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "observeTracingMessages must not have additional properties'",
+ "operations": []
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/entity-client-observeTracingMessages-additionalPropertyType.json b/test/unified-test-format/invalid/entity-client-observeTracingMessages-additionalPropertyType.json
new file mode 100644
index 00000000..0b3a65f5
--- /dev/null
+++ b/test/unified-test-format/invalid/entity-client-observeTracingMessages-additionalPropertyType.json
@@ -0,0 +1,20 @@
+{
+ "description": "entity-client-observeTracingMessages-additionalPropertyType",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0",
+ "observeTracingMessages": {
+ "enableCommandPayload": 0
+ }
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "observeTracingMessages enableCommandPayload must be boolean",
+ "operations": []
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/entity-client-observeTracingMessages-type.json b/test/unified-test-format/invalid/entity-client-observeTracingMessages-type.json
new file mode 100644
index 00000000..de3ef39a
--- /dev/null
+++ b/test/unified-test-format/invalid/entity-client-observeTracingMessages-type.json
@@ -0,0 +1,18 @@
+{
+ "description": "entity-client-observeTracingMessages-type",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0",
+ "observeTracingMessages": "foo"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "observeTracingMessages must be an object",
+ "operations": []
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-additionalProperties.json b/test/unified-test-format/invalid/expectedTracingSpans-additionalProperties.json
new file mode 100644
index 00000000..5947a286
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-additionalProperties.json
@@ -0,0 +1,30 @@
+{
+ "description": "expectedTracingSpans-additionalProperties",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "additional property foo not allowed in expectTracingMessages",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "ignoreExtraSpans": false,
+ "spans": [
+ {
+ "name": "command",
+ "tags": {
+ "db.system": "mongodb"
+ }
+ }
+ ],
+ "foo": 0
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-clientType.json b/test/unified-test-format/invalid/expectedTracingSpans-clientType.json
new file mode 100644
index 00000000..2fe7faea
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-clientType.json
@@ -0,0 +1,28 @@
+{
+ "description": "expectedTracingSpans-clientType",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "client type must be string",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": 0,
+ "spans": [
+ {
+ "name": "command",
+ "tags": {
+ "db.system": "mongodb"
+ }
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-emptyNestedSpan.json b/test/unified-test-format/invalid/expectedTracingSpans-emptyNestedSpan.json
new file mode 100644
index 00000000..8a98d5ba
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-emptyNestedSpan.json
@@ -0,0 +1,29 @@
+{
+ "description": "expectedTracingSpans-emptyNestedSpan",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "nested spans must not have fewer than 1 items'",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "name": "command",
+ "tags": {
+ "db.system": "mongodb"
+ },
+ "nested": []
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-invalidNestedSpan.json b/test/unified-test-format/invalid/expectedTracingSpans-invalidNestedSpan.json
new file mode 100644
index 00000000..79a86744
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-invalidNestedSpan.json
@@ -0,0 +1,31 @@
+{
+ "description": "expectedTracingSpans-invalidNestedSpan",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "nested span must have required property name",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "name": "command",
+ "tags": {
+ "db.system": "mongodb"
+ },
+ "nested": [
+ {}
+ ]
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-missingPropertyClient.json b/test/unified-test-format/invalid/expectedTracingSpans-missingPropertyClient.json
new file mode 100644
index 00000000..2fb1cd5b
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-missingPropertyClient.json
@@ -0,0 +1,27 @@
+{
+ "description": "expectedTracingSpans-missingPropertyClient",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "missing required property client",
+ "operations": [],
+ "expectTracingMessages": {
+ "spans": [
+ {
+ "name": "command",
+ "tags": {
+ "db.system": "mongodb"
+ }
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-missingPropertySpans.json b/test/unified-test-format/invalid/expectedTracingSpans-missingPropertySpans.json
new file mode 100644
index 00000000..acd10307
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-missingPropertySpans.json
@@ -0,0 +1,20 @@
+{
+ "description": "expectedTracingSpans-missingPropertySpans",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "missing required property spans",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0"
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedAdditionalProperties.json b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedAdditionalProperties.json
new file mode 100644
index 00000000..17299f86
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedAdditionalProperties.json
@@ -0,0 +1,28 @@
+{
+ "description": "expectedTracingSpans-spanMalformedAdditionalProperties",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "Span must not have additional properties",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "name": "foo",
+ "tags": {},
+ "nested": [],
+ "foo": "bar"
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedMissingName.json b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedMissingName.json
new file mode 100644
index 00000000..0257cd9b
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedMissingName.json
@@ -0,0 +1,27 @@
+{
+ "description": "expectedTracingSpans-spanMalformedMissingName",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "missing required span name",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "tags": {
+ "db.system": "mongodb"
+ }
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedMissingTags.json b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedMissingTags.json
new file mode 100644
index 00000000..a09ca31c
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedMissingTags.json
@@ -0,0 +1,25 @@
+{
+ "description": "expectedTracingSpans-spanMalformedMissingTags",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "missing required span tags",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "name": "foo"
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedNestedMustBeArray.json b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedNestedMustBeArray.json
new file mode 100644
index 00000000..ccff0410
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedNestedMustBeArray.json
@@ -0,0 +1,27 @@
+{
+ "description": "expectedTracingSpans-spanMalformedNestedMustBeArray",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "nested spans must be an array",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "name": "foo",
+ "tags": {},
+ "nested": {}
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedTagsMustBeObject.json b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedTagsMustBeObject.json
new file mode 100644
index 00000000..72af1c29
--- /dev/null
+++ b/test/unified-test-format/invalid/expectedTracingSpans-spanMalformedTagsMustBeObject.json
@@ -0,0 +1,26 @@
+{
+ "description": "expectedTracingSpans-spanMalformedNestedMustBeObject",
+ "schemaVersion": "1.26",
+ "createEntities": [
+ {
+ "client": {
+ "id": "client0"
+ }
+ }
+ ],
+ "tests": [
+ {
+ "description": "span tags must be an object",
+ "operations": [],
+ "expectTracingMessages": {
+ "client": "client0",
+ "spans": [
+ {
+ "name": "foo",
+ "tags": []
+ }
+ ]
+ }
+ }
+ ]
+}

View File

@ -0,0 +1,26 @@
diff --git a/test/auth/legacy/connection-string.json b/test/auth/legacy/connection-string.json
index 3a099c813..8982b61d5 100644
--- a/test/auth/legacy/connection-string.json
+++ b/test/auth/legacy/connection-string.json
@@ -440,6 +440,21 @@
}
}
},
+ {
+ "description": "should throw an exception if username provided (MONGODB-AWS)",
+ "uri": "mongodb://user@localhost.com/?authMechanism=MONGODB-AWS",
+ "valid": false
+ },
+ {
+ "description": "should throw an exception if username and password provided (MONGODB-AWS)",
+ "uri": "mongodb://user:pass@localhost.com/?authMechanism=MONGODB-AWS",
+ "valid": false
+ },
+ {
+ "description": "should throw an exception if AWS_SESSION_TOKEN provided (MONGODB-AWS)",
+ "uri": "mongodb://localhost/?authMechanism=MONGODB-AWS&authMechanismProperties=AWS_SESSION_TOKEN:token",
+ "valid": false
+ },
{
"description": "should recognise the mechanism with test environment (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:test",

View File

@ -1,84 +1,11 @@
diff --git a/test/connection_logging/connection-logging.json b/test/connection_logging/connection-logging.json
index d40cfbb7e..5799e834d 100644
index 5799e834..72103b3c 100644
--- a/test/connection_logging/connection-logging.json
+++ b/test/connection_logging/connection-logging.json
@@ -272,7 +272,13 @@
"level": "debug",
"component": "connection",
"data": {
- "message": "Connection pool closed",
+ "message": "Connection closed",
+ "driverConnectionId": {
+ "$$type": [
+ "int",
+ "long"
+ ]
+ },
"serverHost": {
"$$type": "string"
},
@@ -281,20 +287,15 @@
"int",
"long"
]
- }
+ },
+ "reason": "Connection pool was closed"
}
},
{
"level": "debug",
"component": "connection",
"data": {
- "message": "Connection closed",
- "driverConnectionId": {
- "$$type": [
- "int",
- "long"
- ]
- },
+ "message": "Connection pool closed",
"serverHost": {
"$$type": "string"
},
@@ -303,8 +304,7 @@
"int",
"long"
]
- },
- "reason": "Connection pool was closed"
+ }
}
}
]
@@ -446,22 +446,6 @@
@@ -446,6 +446,22 @@
}
}
},
- {
- "level": "debug",
- "component": "connection",
- "data": {
- "message": "Connection pool cleared",
- "serverHost": {
- "$$type": "string"
- },
- "serverPort": {
- "$$type": [
- "int",
- "long"
- ]
- }
- }
- },
{
"level": "debug",
"component": "connection",
@@ -514,6 +498,22 @@
]
}
}
+ },
+ {
+ "level": "debug",
+ "component": "connection",
@ -94,6 +21,30 @@ index d40cfbb7e..5799e834d 100644
+ ]
+ }
+ }
+ },
{
"level": "debug",
"component": "connection",
@@ -498,22 +514,6 @@
]
}
}
- },
- {
- "level": "debug",
- "component": "connection",
- "data": {
- "message": "Connection pool cleared",
- "serverHost": {
- "$$type": "string"
- },
- "serverPort": {
- "$$type": [
- "int",
- "long"
- ]
- }
- }
}
]
}

View File

@ -0,0 +1,31 @@
diff --git a/test/discovery_and_monitoring/errors/error_handling_handshake.json b/test/discovery_and_monitoring/errors/error_handling_handshake.json
index 56ca7d113..bf83f46f6 100644
--- a/test/discovery_and_monitoring/errors/error_handling_handshake.json
+++ b/test/discovery_and_monitoring/errors/error_handling_handshake.json
@@ -97,14 +97,22 @@
"outcome": {
"servers": {
"a:27017": {
- "type": "Unknown",
- "topologyVersion": null,
+ "type": "RSPrimary",
+ "setName": "rs",
+ "topologyVersion": {
+ "processId": {
+ "$oid": "000000000000000000000001"
+ },
+ "counter": {
+ "$numberLong": "1"
+ }
+ },
"pool": {
- "generation": 1
+ "generation": 0
}
}
},
- "topologyType": "ReplicaSetNoPrimary",
+ "topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs"
}

View File

@ -0,0 +1,815 @@
diff --git a/test/sessions/snapshot-sessions.json b/test/sessions/snapshot-sessions.json
index 260f8b6f4..8f806ea75 100644
--- a/test/sessions/snapshot-sessions.json
+++ b/test/sessions/snapshot-sessions.json
@@ -988,6 +988,810 @@
}
}
]
+ },
+ {
+ "description": "Find operation with snapshot and snapshot time",
+ "operations": [
+ {
+ "name": "find",
+ "object": "collection0",
+ "arguments": {
+ "session": "session0",
+ "filter": {}
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ },
+ {
+ "_id": 2,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "getSnapshotTime",
+ "object": "session0",
+ "saveResultAsEntity": "savedSnapshotTime"
+ },
+ {
+ "name": "insertOne",
+ "object": "collection0",
+ "arguments": {
+ "document": {
+ "_id": 3,
+ "x": 33
+ }
+ }
+ },
+ {
+ "name": "createEntities",
+ "object": "testRunner",
+ "arguments": {
+ "entities": [
+ {
+ "session": {
+ "id": "session2",
+ "client": "client0",
+ "sessionOptions": {
+ "snapshot": true,
+ "snapshotTime": "savedSnapshotTime"
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ "name": "find",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "filter": {}
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ },
+ {
+ "_id": 2,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "find",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "filter": {}
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ },
+ {
+ "_id": 2,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "find",
+ "object": "collection0",
+ "arguments": {
+ "filter": {}
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ },
+ {
+ "_id": 2,
+ "x": 11
+ },
+ {
+ "_id": 3,
+ "x": 33
+ }
+ ]
+ }
+ ],
+ "expectEvents": [
+ {
+ "client": "client0",
+ "events": [
+ {
+ "commandStartedEvent": {
+ "command": {
+ "find": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$exists": false
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "find": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "find": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "find": "collection0",
+ "readConcern": {
+ "$$exists": false
+ }
+ },
+ "databaseName": "database0"
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "description": "Distinct operation with snapshot and snapshot time",
+ "operations": [
+ {
+ "name": "distinct",
+ "object": "collection0",
+ "arguments": {
+ "session": "session0",
+ "filter": {},
+ "fieldName": "x"
+ },
+ "expectResult": [
+ 11
+ ]
+ },
+ {
+ "name": "getSnapshotTime",
+ "object": "session0",
+ "saveResultAsEntity": "savedSnapshotTime"
+ },
+ {
+ "name": "insertOne",
+ "object": "collection0",
+ "arguments": {
+ "document": {
+ "_id": 3,
+ "x": 33
+ }
+ }
+ },
+ {
+ "name": "createEntities",
+ "object": "testRunner",
+ "arguments": {
+ "entities": [
+ {
+ "session": {
+ "id": "session2",
+ "client": "client0",
+ "sessionOptions": {
+ "snapshot": true,
+ "snapshotTime": "savedSnapshotTime"
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ "name": "distinct",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "filter": {},
+ "fieldName": "x"
+ },
+ "expectResult": [
+ 11
+ ]
+ },
+ {
+ "name": "distinct",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "filter": {},
+ "fieldName": "x"
+ },
+ "expectResult": [
+ 11
+ ]
+ },
+ {
+ "name": "distinct",
+ "object": "collection0",
+ "arguments": {
+ "filter": {},
+ "fieldName": "x"
+ },
+ "expectResult": [
+ 11,
+ 33
+ ]
+ }
+ ],
+ "expectEvents": [
+ {
+ "client": "client0",
+ "events": [
+ {
+ "commandStartedEvent": {
+ "command": {
+ "distinct": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$exists": false
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "distinct": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "distinct": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "distinct": "collection0",
+ "readConcern": {
+ "$$exists": false
+ }
+ },
+ "databaseName": "database0"
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "description": "Aggregate operation with snapshot and snapshot time",
+ "operations": [
+ {
+ "name": "aggregate",
+ "object": "collection0",
+ "arguments": {
+ "session": "session0",
+ "pipeline": [
+ {
+ "$match": {
+ "_id": 1
+ }
+ }
+ ]
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "getSnapshotTime",
+ "object": "session0",
+ "saveResultAsEntity": "savedSnapshotTime"
+ },
+ {
+ "name": "findOneAndUpdate",
+ "object": "collection0",
+ "arguments": {
+ "filter": {
+ "_id": 1
+ },
+ "update": {
+ "$inc": {
+ "x": 1
+ }
+ },
+ "returnDocument": "After"
+ },
+ "expectResult": {
+ "_id": 1,
+ "x": 12
+ }
+ },
+ {
+ "name": "createEntities",
+ "object": "testRunner",
+ "arguments": {
+ "entities": [
+ {
+ "session": {
+ "id": "session2",
+ "client": "client0",
+ "sessionOptions": {
+ "snapshot": true,
+ "snapshotTime": "savedSnapshotTime"
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ "name": "aggregate",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "pipeline": [
+ {
+ "$match": {
+ "_id": 1
+ }
+ }
+ ]
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "aggregate",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "pipeline": [
+ {
+ "$match": {
+ "_id": 1
+ }
+ }
+ ]
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "aggregate",
+ "object": "collection0",
+ "arguments": {
+ "pipeline": [
+ {
+ "$match": {
+ "_id": 1
+ }
+ }
+ ]
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 12
+ }
+ ]
+ }
+ ],
+ "expectEvents": [
+ {
+ "client": "client0",
+ "events": [
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$exists": false
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "$$exists": false
+ }
+ },
+ "databaseName": "database0"
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "description": "countDocuments operation with snapshot and snapshot time",
+ "operations": [
+ {
+ "name": "countDocuments",
+ "object": "collection0",
+ "arguments": {
+ "session": "session0",
+ "filter": {}
+ },
+ "expectResult": 2
+ },
+ {
+ "name": "getSnapshotTime",
+ "object": "session0",
+ "saveResultAsEntity": "savedSnapshotTime"
+ },
+ {
+ "name": "insertOne",
+ "object": "collection0",
+ "arguments": {
+ "document": {
+ "_id": 3,
+ "x": 33
+ }
+ }
+ },
+ {
+ "name": "createEntities",
+ "object": "testRunner",
+ "arguments": {
+ "entities": [
+ {
+ "session": {
+ "id": "session2",
+ "client": "client0",
+ "sessionOptions": {
+ "snapshot": true,
+ "snapshotTime": "savedSnapshotTime"
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ "name": "countDocuments",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "filter": {}
+ },
+ "expectResult": 2
+ },
+ {
+ "name": "countDocuments",
+ "object": "collection0",
+ "arguments": {
+ "session": "session2",
+ "filter": {}
+ },
+ "expectResult": 2
+ },
+ {
+ "name": "countDocuments",
+ "object": "collection0",
+ "arguments": {
+ "filter": {}
+ },
+ "expectResult": 3
+ }
+ ],
+ "expectEvents": [
+ {
+ "client": "client0",
+ "events": [
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$exists": false
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ },
+ "databaseName": "database0"
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "$$exists": false
+ }
+ },
+ "databaseName": "database0"
+ }
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "description": "Mixed operation with snapshot and snapshotTime",
+ "operations": [
+ {
+ "name": "find",
+ "object": "collection0",
+ "arguments": {
+ "session": "session0",
+ "filter": {
+ "_id": 1
+ }
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "getSnapshotTime",
+ "object": "session0",
+ "saveResultAsEntity": "savedSnapshotTime"
+ },
+ {
+ "name": "findOneAndUpdate",
+ "object": "collection0",
+ "arguments": {
+ "filter": {
+ "_id": 1
+ },
+ "update": {
+ "$inc": {
+ "x": 1
+ }
+ },
+ "returnDocument": "After"
+ },
+ "expectResult": {
+ "_id": 1,
+ "x": 12
+ }
+ },
+ {
+ "name": "createEntities",
+ "object": "testRunner",
+ "arguments": {
+ "entities": [
+ {
+ "session": {
+ "id": "session2",
+ "client": "client0",
+ "sessionOptions": {
+ "snapshot": true,
+ "snapshotTime": "savedSnapshotTime"
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ "name": "find",
+ "object": "collection0",
+ "arguments": {
+ "filter": {
+ "_id": 1
+ }
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 12
+ }
+ ]
+ },
+ {
+ "name": "aggregate",
+ "object": "collection0",
+ "arguments": {
+ "pipeline": [
+ {
+ "$match": {
+ "_id": 1
+ }
+ }
+ ],
+ "session": "session2"
+ },
+ "expectResult": [
+ {
+ "_id": 1,
+ "x": 11
+ }
+ ]
+ },
+ {
+ "name": "distinct",
+ "object": "collection0",
+ "arguments": {
+ "fieldName": "x",
+ "filter": {},
+ "session": "session2"
+ },
+ "expectResult": [
+ 11
+ ]
+ }
+ ],
+ "expectEvents": [
+ {
+ "client": "client0",
+ "events": [
+ {
+ "commandStartedEvent": {
+ "command": {
+ "find": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$exists": false
+ }
+ }
+ }
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "find": "collection0",
+ "readConcern": {
+ "$$exists": false
+ }
+ }
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "aggregate": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ }
+ }
+ },
+ {
+ "commandStartedEvent": {
+ "command": {
+ "distinct": "collection0",
+ "readConcern": {
+ "level": "snapshot",
+ "atClusterTime": {
+ "$$matchesEntity": "savedSnapshotTime"
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ ]
}
]
}

View File

@ -1,140 +0,0 @@
#!/bin/bash
# Utility functions used by pymongo evergreen scripts.
set -eu
find_python3() {
PYTHON=""
# Find a suitable toolchain version, if available.
if [ "$(uname -s)" = "Darwin" ]; then
PYTHON="/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3"
elif [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
PYTHON="C:/python/Python39/python.exe"
else
# Prefer our own toolchain, fall back to mongodb toolchain if it has Python 3.9+.
if [ -f "/opt/python/3.9/bin/python3" ]; then
PYTHON="/opt/python/Current/bin/python3"
elif is_python_39 "$(command -v /opt/mongodbtoolchain/v5/bin/python3)"; then
PYTHON="/opt/mongodbtoolchain/v5/bin/python3"
elif is_python_39 "$(command -v /opt/mongodbtoolchain/v4/bin/python3)"; then
PYTHON="/opt/mongodbtoolchain/v4/bin/python3"
elif is_python_39 "$(command -v /opt/mongodbtoolchain/v3/bin/python3)"; then
PYTHON="/opt/mongodbtoolchain/v3/bin/python3"
fi
fi
# Add a fallback system python3 if it is available and Python 3.9+.
if [ -z "$PYTHON" ]; then
if is_python_39 "$(command -v python3)"; then
PYTHON="$(command -v python3)"
fi
fi
if [ -z "$PYTHON" ]; then
echo "Cannot test without python3.9+ installed!"
exit 1
fi
echo "$PYTHON"
}
# Usage:
# createvirtualenv /path/to/python /output/path/for/venv
# * param1: Python binary to use for the virtualenv
# * param2: Path to the virtualenv to create
createvirtualenv () {
PYTHON=$1
VENVPATH=$2
# Prefer venv
VENV="$PYTHON -m venv"
if [ "$(uname -s)" = "Darwin" ]; then
VIRTUALENV="$PYTHON -m virtualenv"
else
VIRTUALENV=$(command -v virtualenv 2>/dev/null || echo "$PYTHON -m virtualenv")
VIRTUALENV="$VIRTUALENV -p $PYTHON"
fi
if ! $VENV $VENVPATH 2>/dev/null; then
# Workaround for bug in older versions of virtualenv.
$VIRTUALENV $VENVPATH 2>/dev/null || $VIRTUALENV $VENVPATH
fi
if [ "Windows_NT" = "${OS:-}" ]; then
# Workaround https://bugs.python.org/issue32451:
# mongovenv/Scripts/activate: line 3: $'\r': command not found
dos2unix $VENVPATH/Scripts/activate || true
. $VENVPATH/Scripts/activate
else
. $VENVPATH/bin/activate
fi
export PIP_QUIET=1
python -m pip install --upgrade pip
}
# Usage:
# testinstall /path/to/python /path/to/.whl ["no-virtualenv"]
# * param1: Python binary to test
# * param2: Path to the wheel to install
# * param3 (optional): If set to a non-empty string, don't create a virtualenv. Used in manylinux containers.
testinstall () {
PYTHON=$1
RELEASE=$2
NO_VIRTUALENV=$3
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
if [ -z "$NO_VIRTUALENV" ]; then
createvirtualenv $PYTHON venvtestinstall
PYTHON=python
fi
$PYTHON -m pip install --upgrade $RELEASE
cd tools
if [ "$PYTHON_IMPL" = "CPython" ]; then
$PYTHON fail_if_no_c.py
fi
$PYTHON -m pip uninstall -y pymongo
cd ..
if [ -z "$NO_VIRTUALENV" ]; then
deactivate
rm -rf venvtestinstall
fi
}
# Function that returns success if the provided Python binary is version 3.9 or later
# Usage:
# is_python_39 /path/to/python
# * param1: Python binary
is_python_39() {
if [ -z "$1" ]; then
return 1
elif $1 -c "import sys; exit(sys.version_info[:2] < (3, 9))"; then
# runs when sys.version_info[:2] >= (3, 9)
return 0
else
return 1
fi
}
# Function that gets a python binary given a python version string.
# Versions can be of the form 3.xx or pypy3.xx.
get_python_binary() {
version=$1
if [ "$(uname -s)" = "Darwin" ]; then
PYTHON="/Library/Frameworks/Python.Framework/Versions/$version/bin/python3"
elif [ "Windows_NT" = "${OS:-}" ]; then
version=$(echo $version | cut -d. -f1,2 | sed 's/\.//g')
if [ -n "${IS_WIN32:-}" ]; then
PYTHON="C:/python/32/Python$version/python.exe"
else
PYTHON="C:/python/Python$version/python.exe"
fi
else
PYTHON="/opt/python/$version/bin/python3"
fi
if is_python_39 "$(command -v $PYTHON)"; then
echo "$PYTHON"
else
echo "Could not find suitable python binary for '$version'" >&2
return 1
fi
}

44
.github/copilot-instructions.md vendored Normal file
View File

@ -0,0 +1,44 @@
When reviewing code, focus on:
## Security Critical Issues
- Check for hardcoded secrets, API keys, or credentials.
- Check for instances of potential method call injection, dynamic code execution, symbol injection or other code injection vulnerabilities.
## Performance Red Flags
- Spot inefficient loops and algorithmic issues.
- Check for memory leaks and resource cleanup.
## Code Quality Essentials
- Methods should be focused and appropriately sized. If a method is doing too much, suggest refactorings to split it up.
- Use clear, descriptive naming conventions.
- Avoid encapsulation violations and ensure proper separation of concerns.
- All public classes, modules, and methods should have clear documentation in Sphinx format.
## PyMongo-specific Concerns
- Do not review files within `pymongo/synchronous` or files in `test/` that also have a file of the same name in `test/asynchronous` unless the reviewed changes include a `_IS_SYNC` statement. PyMongo generates these files from `pymongo/asynchronous` and `test/asynchronous` using `tools/synchro.py`.
- All asynchronous functions must not call any blocking I/O.
## Review Style
- Be specific and actionable in feedback.
- Explain the "why" behind recommendations.
- Acknowledge good patterns when you see them.
- Ask clarifying questions when code intent is unclear.
Always prioritize security vulnerabilities and performance issues that could impact users.
Always suggest changes to improve readability and testability. For example, this suggestion seeks to make the code more readable, reusable, and testable:
```python
# Instead of:
if user.email and "@" in user.email and len(user.email) > 5:
submit_button.enabled = True
else:
submit_button.enabled = False
# Consider:
def valid_email(email):
return email and "@" in email and len(email) > 5
submit_button.enabled = valid_email(user.email)
```

View File

@ -5,6 +5,8 @@ updates:
directory: "/"
schedule:
interval: "weekly"
cooldown:
default-days: 7
groups:
actions:
patterns:

33
.github/pull_request_template.md vendored Normal file
View File

@ -0,0 +1,33 @@
<!-- Thanks for contributing! -->
<!-- Please ensure that the title of the PR is in the following form:
[JIRA TICKET]: Issue Title
If you are an external contributor and there is no JIRA ticket associated with your change, then use your best judgement
for the PR title. A MongoDB employee will create a JIRA ticket and edit the name and links as appropriate.
Note on AI Contributions:
We do not accept pull requests that are primarily or substantially generated by AI tools (ChatGPT, Copilot, etc.).
All contributions must be written and understood by human contributors.
-->
[JIRA TICKET]
## Changes in this PR
<!-- What changes did you make to the code? What new APIs (public or private) were added, removed, or edited to generate
the desired outcome explained in the above summary? -->
## Test Plan
<!-- How did you test the code? If you added unit tests, you can say that. If you didnt introduce unit tests, explain why.
All code should be tested in some way so please list what your validation strategy was. -->
## Checklist
<!-- Do not delete the items provided on this checklist. -->
### Checklist for Author
- [ ] Did you update the changelog (if necessary)?
- [ ] Is there test coverage?
- [ ] Is any followup work tracked in a JIRA ticket? If so, add link(s).
### Checklist for Reviewer
- [ ] Does the title of the PR reference a JIRA Ticket?
- [ ] Do you fully understand the implementation? (Would you be comfortable explaining how this code works to someone else?)
- [ ] Is all relevant documentation (README or docstring) updated?

View File

@ -38,15 +38,15 @@ jobs:
build-mode: none
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
ref: ${{ inputs.ref }}
persist-credentials: false
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v4
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@ -63,6 +63,6 @@ jobs:
pip install -e .
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@51f77329afa6477de8c49fc9c7046c15b9a4e79d # v3
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v4
with:
category: "/language:${{matrix.language}}"

View File

@ -33,11 +33,11 @@ jobs:
outputs:
version: ${{ steps.pre-publish.outputs.version }}
steps:
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v2
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v3
with:
app_id: ${{ vars.APP_ID }}
private_key: ${{ secrets.APP_PRIVATE_KEY }}
- uses: mongodb-labs/drivers-github-tools/setup@v2
- uses: mongodb-labs/drivers-github-tools/setup@v3
with:
aws_role_arn: ${{ secrets.AWS_ROLE_ARN }}
aws_region_name: ${{ vars.AWS_REGION_NAME }}
@ -45,7 +45,7 @@ jobs:
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
- name: Get hatch
run: pip install hatch
- uses: mongodb-labs/drivers-github-tools/create-branch@v2
- uses: mongodb-labs/drivers-github-tools/create-branch@v3
id: create-branch
with:
branch_name: ${{ inputs.branch_name }}

View File

@ -41,26 +41,27 @@ jobs:
- [ubuntu-latest, "manylinux_i686", "cp3*-manylinux_i686"]
- [windows-2022, "win_amd6", "cp3*-win_amd64"]
- [windows-2022, "win32", "cp3*-win32"]
- [windows-11-arm, "win_arm64", "cp3*-win_arm64"]
- [macos-14, "macos", "cp*-macosx_*"]
steps:
- name: Checkout pymongo
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
ref: ${{ inputs.ref }}
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
cache: 'pip'
python-version: 3.9
python-version: 3.11
cache-dependency-path: 'pyproject.toml'
allow-prereleases: true
- name: Set up QEMU
if: runner.os == 'Linux'
uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
with:
# setup-qemu-action by default uses `tonistiigi/binfmt:latest` image,
# which is out of date. This causes seg faults during build.
@ -69,24 +70,16 @@ jobs:
platforms: all
- name: Install cibuildwheel
# Note: the default manylinux is manylinux2014
# Note: the default manylinux is manylinux_2_28
run: |
python -m pip install -U pip
python -m pip install "cibuildwheel>=2.20,<3"
python -m pip install "cibuildwheel>=3.2.0,<4"
- name: Build wheels
env:
CIBW_BUILD: ${{ matrix.buildplat[2] }}
run: python -m cibuildwheel --output-dir wheelhouse
- name: Build manylinux1 wheels
if: ${{ matrix.buildplat[1] == 'manylinux_x86_64' || matrix.buildplat[1] == 'manylinux_i686' }}
env:
CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
CIBW_MANYLINUX_I686_IMAGE: manylinux1
CIBW_BUILD: "cp39-${{ matrix.buildplat[1] }} cp39-${{ matrix.buildplat[1] }}"
run: python -m cibuildwheel --output-dir wheelhouse
- name: Assert all versions in wheelhouse
if: ${{ ! startsWith(matrix.buildplat[1], 'macos') }}
run: |
@ -95,10 +88,11 @@ jobs:
ls wheelhouse/*cp311*.whl
ls wheelhouse/*cp312*.whl
ls wheelhouse/*cp313*.whl
ls wheelhouse/*cp314*.whl
# Free-threading builds:
ls wheelhouse/*cp313t*.whl
ls wheelhouse/*cp314t*.whl
- uses: actions/upload-artifact@v4
- uses: actions/upload-artifact@v6
with:
name: wheel-${{ matrix.buildplat[1] }}
path: ./wheelhouse/*.whl
@ -106,18 +100,18 @@ jobs:
make_sdist:
name: Make SDist
runs-on: macos-13
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
ref: ${{ inputs.ref }}
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
# Build sdist on lowest supported Python
python-version: '3.9'
python-version: "3.9"
- name: Build SDist
run: |
@ -131,7 +125,7 @@ jobs:
cd ..
python -c "from pymongo import has_c; assert has_c()"
- uses: actions/upload-artifact@v4
- uses: actions/upload-artifact@v6
with:
name: "sdist"
path: ./dist/*.tar.gz
@ -142,13 +136,13 @@ jobs:
name: Download Wheels
steps:
- name: Download all workflow run artifacts
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
- name: Flatten directory
working-directory: .
run: |
find . -mindepth 2 -type f -exec mv {} . \;
find . -type d -empty -delete
- uses: actions/upload-artifact@v4
- uses: actions/upload-artifact@v6
with:
name: all-dist-${{ github.run_id }}
path: "./*"

View File

@ -1,23 +0,0 @@
# [JIRA Ticket ID](Link to Ticket)
<!-- Please provide explicit URL link to the corresponding JIRA ticket. -->
# Summary
<!-- Please provide a high level overview of what changes have been made. -->
# Changes in this PR
<!-- Highlight any high level architecture changes if the summary doesn't already cover the scope. -->
# Test Plan
<!-- Talk through any unit tests added, and if this is a bug fix, please add repro steps in the event the fix needs to be verified. -->
# Screenshots (Optional)
<!-- Add a before and after picture to indicate changes. -->
# Callouts or Follow-up items (Optional)
<!-- Any additional info not already specified in the PR including but not limited to:
1. Potential stakeholders
2. Slack threads etc.
3. Implementation details that need additional oversight
4. Callouts on future tactics
-->

View File

@ -38,17 +38,16 @@ jobs:
outputs:
version: ${{ steps.pre-publish.outputs.version }}
steps:
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v2
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v3
with:
app_id: ${{ vars.APP_ID }}
private_key: ${{ secrets.APP_PRIVATE_KEY }}
- uses: mongodb-labs/drivers-github-tools/setup@v2
- uses: mongodb-labs/drivers-github-tools/setup@v3
with:
aws_role_arn: ${{ secrets.AWS_ROLE_ARN }}
aws_region_name: ${{ vars.AWS_REGION_NAME }}
aws_secret_id: ${{ secrets.AWS_SECRET_ID }}
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
- uses: mongodb-labs/drivers-github-tools/python/pre-publish@v2
- uses: mongodb-labs/drivers-github-tools/python/pre-publish@v3
id: pre-publish
with:
dry_run: ${{ env.DRY_RUN }}
@ -76,19 +75,19 @@ jobs:
id-token: write
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
name: all-dist-${{ github.run_id }}
path: dist/
- name: Publish package distributions to TestPyPI
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # release/v1
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1
with:
repository-url: https://test.pypi.org/legacy/
skip-existing: true
attestations: ${{ env.DRY_RUN }}
- name: Publish package distributions to PyPI
if: startsWith(env.DRY_RUN, 'false')
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # release/v1
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1
post-publish:
needs: [publish]
@ -100,17 +99,16 @@ jobs:
attestations: write
security-events: write
steps:
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v2
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v3
with:
app_id: ${{ vars.APP_ID }}
private_key: ${{ secrets.APP_PRIVATE_KEY }}
- uses: mongodb-labs/drivers-github-tools/setup@v2
- uses: mongodb-labs/drivers-github-tools/setup@v3
with:
aws_role_arn: ${{ secrets.AWS_ROLE_ARN }}
aws_region_name: ${{ vars.AWS_REGION_NAME }}
aws_secret_id: ${{ secrets.AWS_SECRET_ID }}
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
- uses: mongodb-labs/drivers-github-tools/python/post-publish@v2
- uses: mongodb-labs/drivers-github-tools/python/post-publish@v3
with:
following_version: ${{ env.FOLLOWING_VERSION }}
product_name: ${{ env.PRODUCT_NAME }}

104
.github/workflows/sbom.yml vendored Normal file
View File

@ -0,0 +1,104 @@
name: Generate SBOM
# This workflow uses cyclonedx-py and publishes an sbom.json artifact.
# It runs on manual trigger or when package files change on main branch,
# and creates a PR with the updated SBOM.
# Internal documentation: go/sbom-scope
on:
workflow_dispatch: {}
push:
branches: ['master']
paths:
- 'requirements.txt'
- 'requirements/**.txt'
- '!requirements/docs.txt'
- '!requirements/test.txt'
permissions:
contents: write
pull-requests: write
jobs:
sbom:
name: Generate SBOM and Create PR
runs-on: ubuntu-latest
concurrency:
group: sbom-${{ github.ref }}
cancel-in-progress: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Generate SBOM
run: |
python -m venv .venv
source .venv/bin/activate
python tools/generate_sbom_requirements.py
pip install -r sbom-requirements.txt
pip install .
pip uninstall -y pip setuptools
deactivate
python -m venv .venv-sbom
source .venv-sbom/bin/activate
pip install cyclonedx-bom==7.2.1
cyclonedx-py environment --spec-version 1.5 --output-format JSON --output-file sbom.json .venv
# Add PURL for pymongo (local package doesn't get PURL automatically)
jq '(.components[] | select(.name == "pymongo" and .purl == null)) |= (. + {purl: ("pkg:pypi/pymongo@" + .version)})' sbom.json > sbom.tmp.json && mv sbom.tmp.json sbom.json
- name: Download CycloneDX CLI
run: |
curl -L -s -o /tmp/cyclonedx "https://github.com/CycloneDX/cyclonedx-cli/releases/download/v0.29.1/cyclonedx-linux-x64"
chmod +x /tmp/cyclonedx
- name: Validate SBOM
run: /tmp/cyclonedx validate --input-file sbom.json --fail-on-errors
- name: Cleanup
if: always()
run: rm -rf .venv .venv-sbom sbom-requirements.txt
- name: Upload SBOM artifact
uses: actions/upload-artifact@v6
with:
name: sbom
path: sbom.json
if-no-files-found: error
- name: Create Pull Request
uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v8
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore: Update SBOM after dependency changes'
branch: auto-update-sbom-${{ github.run_id }}
delete-branch: true
title: 'Automation: Update SBOM'
body: |
## Automated SBOM Update
This PR was automatically generated because dependency manifest files changed.
### Changes
- Updated `sbom.json` to reflect current dependencies
### Verification
The SBOM was generated using cyclonedx-py v7.2.1 with the current Python environment.
### Triggered by
- Commit: ${{ github.sha }}
- Workflow run: ${{ github.run_id }}
---
_This PR was created automatically by the [SBOM workflow](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})_
labels: |
sbom
automated
dependencies

View File

@ -14,21 +14,24 @@ defaults:
run:
shell: bash -eux {0}
permissions:
contents: read
jobs:
static:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install just
uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "3.9"
python-version: "3.10"
- name: Install just
run: uv tool install rust-just
- name: Install Python dependencies
run: |
just install
@ -56,16 +59,16 @@ jobs:
matrix:
# Tests currently only pass on ubuntu on GitHub Actions.
os: [ubuntu-latest]
python-version: ["3.9", "pypy-3.10", "3.13t"]
python-version: ["3.10", "pypy-3.11", "3.13t"]
mongodb-version: ["8.0"]
name: CPython ${{ matrix.python-version }}-${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -76,20 +79,51 @@ jobs:
- name: Run tests
run: uv run --extra test pytest -v
coverage:
# This enables a coverage report for a given PR, which will be augmented by
# the combined codecov report uploaded in Evergreen.
runs-on: ubuntu-latest
name: Coverage
steps:
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "3.10"
- id: setup-mongodb
uses: mongodb-labs/drivers-evergreen-tools@master
with:
version: "8.0"
- name: Install just
run: uv tool install rust-just
- name: Setup tests
run: COVERAGE=1 just setup-tests
- name: Run tests
run: just run-tests
- name: Generate xml report
run: uv tool run --with "coverage[toml]" coverage xml
- name: Upload test results to Codecov
uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de # v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
doctest:
runs-on: ubuntu-latest
name: DocTest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install just
uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "3.9"
python-version: "3.10"
- name: Install just
run: uv tool install rust-just
- id: setup-mongodb
uses: mongodb-labs/drivers-evergreen-tools@master
with:
@ -105,16 +139,16 @@ jobs:
name: Docs Checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "3.9"
python-version: "3.10"
- name: Install just
uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
run: uv tool install rust-just
- name: Install dependencies
run: just install
- name: Build docs
@ -124,16 +158,16 @@ jobs:
name: Link Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "3.9"
python-version: "3.10"
- name: Install just
uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
run: uv tool install rust-just
- name: Install dependencies
run: just install
- name: Build docs
@ -144,18 +178,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: ["3.9", "3.11"]
python: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "${{matrix.python}}"
- name: Install just
uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
run: uv tool install rust-just
- name: Install dependencies
run: |
just install
@ -163,25 +197,55 @@ jobs:
run: |
just typing
integration_tests:
runs-on: ubuntu-latest
name: Integration Tests
steps:
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
enable-cache: true
python-version: "3.10"
- name: Install just
run: uv tool install rust-just
- name: Install dependencies
run: |
just install
- id: setup-mongodb
uses: mongodb-labs/drivers-evergreen-tools@master
- name: Run tests
run: |
just integration-tests
- id: setup-mongodb-ssl
uses: mongodb-labs/drivers-evergreen-tools@master
with:
ssl: true
- name: Run tests
run: |
just integration-tests
make_sdist:
runs-on: ubuntu-latest
name: "Make an sdist"
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
cache: 'pip'
cache-dependency-path: 'pyproject.toml'
# Build sdist on lowest supported Python
python-version: '3.9'
python-version: "3.9"
- name: Build SDist
shell: bash
run: |
pip install build
python -m build --sdist
- uses: actions/upload-artifact@v4
- uses: actions/upload-artifact@v6
with:
name: "sdist"
path: dist/*.tar.gz
@ -193,7 +257,9 @@ jobs:
timeout-minutes: 20
steps:
- name: Download sdist
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
path: sdist/
- name: Unpack SDist
shell: bash
run: |
@ -202,12 +268,12 @@ jobs:
mkdir test
tar --strip-components=1 -zxf *.tar.gz -C ./test
ls test
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
cache: 'pip'
cache-dependency-path: 'sdist/test/pyproject.toml'
# Test sdist on lowest supported Python
python-version: '3.9'
python-version: "3.9"
- id: setup-mongodb
uses: mongodb-labs/drivers-evergreen-tools@master
- name: Run connect test from sdist
@ -223,50 +289,23 @@ jobs:
permissions:
contents: read
runs-on: ubuntu-latest
name: Test using minimum dependencies and supported Python
name: Test minimum dependencies and Python
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7
with:
python-version: '3.9'
python-version: "3.9"
- id: setup-mongodb
uses: mongodb-labs/drivers-evergreen-tools@master
with:
version: "8.0"
# Async and our test_dns do not support dnspython 1.X, so we don't run async or dns tests here
- name: Run tests
shell: bash
run: |
uv venv
source .venv/bin/activate
uv pip install -e ".[test]" --resolution=lowest-direct
pytest -v test/test_srv_polling.py
test_minimum_for_async:
permissions:
contents: read
runs-on: ubuntu-latest
name: Test async's minimum dependencies and Python
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v5
with:
python-version: '3.9'
- id: setup-mongodb
uses: mongodb-labs/drivers-evergreen-tools@master
with:
version: "8.0"
# The lifetime kwarg we use in srv resolution was added to the async resolver API in dnspython 2.1.0
- name: Run tests
shell: bash
run: |
uv venv
source .venv/bin/activate
uv pip install -e ".[test]" --resolution=lowest-direct dnspython==2.1.0 --force-reinstall
uv pip install -e ".[test]" --resolution=lowest-direct --force-reinstall
pytest -v test/test_srv_polling.py test/test_dns.py test/asynchronous/test_srv_polling.py test/asynchronous/test_dns.py

View File

@ -14,8 +14,8 @@ jobs:
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Run zizmor 🌈
uses: zizmorcore/zizmor-action@383d31df2eb66a2f42db98c9654bdc73231f3e3a
uses: zizmorcore/zizmor-action@0dce2577a4760a2749d8cfb7a84b7d5585ebcb7d # v0.5.0

2
.gitignore vendored
View File

@ -41,4 +41,6 @@ test/lambda/*.json
# test results and logs
xunit-results/
coverage.xml
server.log
.coverage

View File

@ -122,3 +122,14 @@ repos:
language: python
require_serial: true
additional_dependencies: ["shrub.py>=3.10.0", "pyyaml>=6.0.2"]
- id: uv-lock
name: uv-lock
entry: uv lock
language: python
require_serial: true
files: ^(uv\.lock|pyproject\.toml|requirements.txt|requirements/.*\.txt)$
pass_filenames: false
fail_fast: true
additional_dependencies:
- "uv>=0.8.4"

View File

@ -16,7 +16,7 @@ be of interest or that has already been addressed.
## Supported Interpreters
PyMongo supports CPython 3.9+ and PyPy3.10+. Language features not
PyMongo supports CPython 3.9+ and PyPy3.9+. Language features not
supported by all interpreters can not be used.
## Style Guide
@ -194,10 +194,10 @@ the pages will re-render and the browser will automatically refresh.
- Run `just install` to set a local virtual environment, or you can manually
create a virtual environment and run `pytest` directly. If you want to use a specific
version of Python, remove the `.venv` folder and set `PYTHON_BINARY` before running `just install`.
version of Python, set `UV_PYTHON` before running `just install`.
- Ensure you have started the appropriate Mongo Server(s). You can run `just run-server` with optional args
to set up the server. All given options will be passed to
[`run-orchestration.sh`](https://github.com/mongodb-labs/drivers-evergreen-tools/blob/master/.evergreen/run-orchestration.sh). Run `$DRIVERS_TOOLS/evergreen/run-orchestration.sh -h`
[`run-mongodb.sh`](https://github.com/mongodb-labs/drivers-evergreen-tools/blob/master/.evergreen/run-mongodb.sh). Run `$DRIVERS_TOOLS/.evergreen/run-mongodb.sh start -h`
for a full list of options.
- Run `just test` or `pytest` to run all of the tests.
- Append `test/<mod_name>.py::<class_name>::<test_name>` to run
@ -205,6 +205,7 @@ the pages will re-render and the browser will automatically refresh.
and the `<class_name>` to test a full module. For example:
`just test test/test_change_stream.py::TestUnifiedChangeStreamsErrors::test_change_stream_errors_on_ElectionInProgress`.
- Use the `-k` argument to select tests by pattern.
- Run `just test-coverage` to run tests with coverage and display a report. After running tests with coverage, use `just coverage-html` to generate an HTML report in `htmlcov/index.html`.
## Running tests that require secrets, services, or other configuration
@ -335,7 +336,7 @@ Locally you can run:
- Run `just run-server`.
- Run `just setup-tests`.
- Run `UV_PYTHON=3.13t just run-tests`.
- Run `UV_PYTHON=3.14t just run-tests`.
### AWS Lambda tests
@ -355,13 +356,6 @@ Note: these tests can only be run from an Evergreen Linux host that has the Pyth
The `mode` can be `standalone` or `embedded`. For the `replica_set` version of the tests, use
`TOPOLOGY=replica_set just run-server`.
### Atlas Data Lake tests.
You must have `docker` or `podman` installed locally.
- Run `just setup-tests data_lake`.
- Run `just run-tests`.
### OCSP tests
- Export the orchestration file, e.g. `export ORCHESTRATION_FILE=rsa-basic-tls-ocsp-disableStapling.json`.
@ -389,11 +383,21 @@ If you are running one of the `no-responder` tests, omit the `run-server` step.
- Finally, you can use `just setup-tests --debug-log`.
- For evergreen patch builds, you can use `evergreen patch --param DEBUG_LOG=1` to enable debug logs for failed tests in the patch.
## Testing minimum dependencies
To run any of the test suites with minimum supported dependencies, pass `--test-min-deps` to
`just setup-tests`.
## Testing time-dependent operations
- `test.utils_shared.delay` - One can trigger an arbitrarily long-running operation on the server using this delay utility
in combination with a `$where` operation. Use this to test behaviors around timeouts or signals.
## Adding a new test suite
- If adding new tests files that should only be run for that test suite, add a pytest marker to the file and add
to the list of pytest markers in `pyproject.toml`. Then add the test suite to the `TEST_SUITE_MAP` in `.evergreen/scripts/utils.py`. If for some reason it is not a pytest-runnable test, add it to the list of `EXTRA_TESTS` instead.
- If the test uses Atlas or otherwise doesn't use `run-orchestration.sh`, add it to the `NO_RUN_ORCHESTRATION` list in
- If the test uses Atlas or otherwise doesn't use `run-mongodb.sh`, add it to the `NO_RUN_ORCHESTRATION` list in
`.evergreen/scripts/utils.py`.
- If there is something special required to run the local server or there is an extra flag that should always be set
like `AUTH`, add that logic to `.evergreen/scripts/run_server.py`.
@ -401,6 +405,16 @@ If you are running one of the `no-responder` tests, omit the `run-server` step.
- If there are any special test considerations, including not running `pytest` at all, handle it in `.evergreen/scripts/run_tests.py`.
- If there are any services or atlas clusters to teardown, handle them in `.evergreen/scripts/teardown_tests.py`.
- Add functions to generate the test variant(s) and task(s) to the `.evergreen/scripts/generate_config.py`.
- There are some considerations about the Python version used in the test:
- If a specific version of Python is needed in a task that is running on variants with a toolchain, use
``TOOLCHAIN_VERSION`` (e.g. `TOOLCHAIN_VERSION=3.10`). The actual path lookup needs to be done on the host, since
tasks are host-agnostic.
- If a specific Python binary is needed (for example on the FIPS host), set `UV_PYTHON=/path/to/python`.
- If a specific Python version is needed and the toolchain will not be available, use `UV_PYTHON` (e.g. `UV_PYTHON=3.11`).
- The default if neither ``TOOLCHAIN_VERSION`` or ``UV_PYTHON`` is set is to use UV to install the minimum
supported version of Python and use that. This ensures a consistent behavior across host types that do not
have the Python toolchain (e.g. Azure VMs), by having a known version of Python with the build headers (`Python.h`)
needed to build the C extensions.
- Regenerate the test variants and tasks using `pre-commit run --all-files generate-config`.
- Make sure to add instructions for running the test suite to `CONTRIBUTING.md`.
@ -413,6 +427,14 @@ a use the ticket number as the "reason" parameter to the decorator, e.g. `@flaky
When running tests locally (not in CI), the `flaky` decorator will be disabled unless `ENABLE_FLAKY` is set.
To disable the `flaky` decorator in CI, you can use `evergreen patch --param DISABLE_FLAKY=1`.
## Integration Tests
The `integration_tests` directory has a set of scripts that verify the usage of PyMongo with downstream packages or frameworks. See the [README](./integration_tests/README.md) for more information.
To run the tests, use `just integration_tests`.
The tests should be able to run with and without SSL enabled.
## Specification Tests
The MongoDB [specifications repository](https://github.com/mongodb/specifications)
@ -466,6 +488,7 @@ results into the patch file.
For example: the imaginary, unimplemented PYTHON-1234 ticket has associated spec test changes. To add those changes to `PYTHON-1234.patch`), do the following:
```bash
git diff HEAD~1 path/to/file >> .evergreen/spec-patch/PYTHON-1234.patch
```
#### Running Locally
Both `resync-all-specs.sh` and `resync-all-specs.py` can be run locally (and won't generate a PR).
@ -503,8 +526,10 @@ Use this generated file as a starting point for the completed conversion.
The script is used like so: `python tools/convert_test_to_async.py [test_file.py]`
## Generating a flame graph using py-spy
## CPU profiling
To profile a test script and generate a flame graph, follow these steps:
1. Install `py-spy` if you haven't already:
```bash
pip install py-spy
@ -514,3 +539,31 @@ To profile a test script and generate a flame graph, follow these steps:
(Note: on macOS you will need to run this command using `sudo` to allow `py-spy` to attach to the Python process.)
4. If you need to include native code (for example the C extensions), profiling should be done on a Linux system, as macOS and Windows do not support the `--native` option of `py-spy`.
Creating an ubuntu Evergreen spawn host and using `scp` to copy the flamegraph `.svg` file back to your local machine is the best way to do this.
5. You can then view the flamegraph using an SVG viewer like a browser.
## Memory profiling
To test for a memory leak or any memory-related issues, the current best tool is [memray](https://bloomberg.github.io/memray/overview.html).
In order to include code from our C extensions, it must be run in native mode, on Linux.
To do so, either spin up an Ubuntu docker container or an Ubuntu Evergreen spawn host.
From the spawn host or Ubuntu image, do the following:
1. Install `memray` if you haven't already:
```bash
pip install memray
```
2. Inside your test script, perform any required setup and then loop over the code you want to profile for improved sampling.
3. Run memray with the script under test with the `--native` flag, e.g. `python -m memray run --native -o test.bin <path/to/script>`.
4. Generate the flamegraph with `python -m memray flamegraph -o test.html test.bin`.
See the [docs](https://bloomberg.github.io/memray/flamegraph.html) for more options.
5. Then, from the host computer, use either scp or docker cp to copy the flamegraph, e.g. `scp ubuntu@ec2-3-82-52-49.compute-1.amazonaws.com:/home/ubuntu/test.html .`.
6. You can then view the flamegraph html in a browser.
## Dependabot updates
Dependabot will raise PRs at most once per week, grouped by GitHub Actions updates and Python requirement
file updates. We have a pre-commit hook that will update the `uv.lock` file when requirements change.
To update the lock file on a failing PR, you can use a method like `gh pr checkout <pr number>`, then run
`just lint uv-lock` to update the lock file, and then push the changes. If a typing dependency has changed,
also run `just typing` and handle any new findings.

View File

@ -97,7 +97,7 @@ package that is incompatible with PyMongo.
## Dependencies
PyMongo supports CPython 3.9+ and PyPy3.10+.
PyMongo supports CPython 3.9+ and PyPy3.9+.
Required dependencies:
@ -139,7 +139,8 @@ python -m pip install "pymongo[snappy]"
```
Wire protocol compression with zstandard requires
[zstandard](https://pypi.org/project/zstandard):
[backports.zstd](https://pypi.org/project/backports.zstd)
when used with Python versions before 3.14:
```bash
python -m pip install "pymongo[zstd]"

View File

@ -1009,7 +1009,7 @@ def _dict_to_bson(
try:
elements.append(_element_to_bson(key, value, check_keys, opts))
except InvalidDocument as err:
raise InvalidDocument(f"Invalid document {doc} | {err}") from err
raise InvalidDocument(f"Invalid document: {err}", doc) from err
except AttributeError:
raise TypeError(f"encoder expected a mapping type but got: {doc!r}") from None
@ -1109,7 +1109,9 @@ def _decode_all(data: _ReadableBuffer, opts: CodecOptions[_DocumentType]) -> lis
while position < end:
obj_size = _UNPACK_INT_FROM(data, position)[0]
if data_len - position < obj_size:
raise InvalidBSON("invalid object size")
raise InvalidBSON(
f"invalid object size: expected {obj_size}, got {data_len - position}"
)
obj_end = position + obj_size - 1
if data[obj_end] != 0:
raise InvalidBSON("bad eoo")
@ -1327,7 +1329,7 @@ def decode_iter(
elements = data[position : position + obj_size]
position += obj_size
yield _bson_to_dict(elements, opts) # type:ignore[misc]
yield _bson_to_dict(elements, opts)
@overload
@ -1373,7 +1375,7 @@ def decode_file_iter(
raise InvalidBSON("cut off in middle of objsize")
obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4
elements = size_data + file_obj.read(max(0, obj_size))
yield _bson_to_dict(elements, opts) # type:ignore[arg-type, misc]
yield _bson_to_dict(elements, opts) # type:ignore[misc]
def is_valid(bson: bytes) -> bool:

View File

@ -356,7 +356,8 @@ static PyObject* datetime_ms_from_millis(PyObject* self, long long millis){
if (!(ll_millis = PyLong_FromLongLong(millis))){
return NULL;
}
dt = PyObject_CallFunctionObjArgs(state->DatetimeMS, ll_millis, NULL);
PyObject* args[1] = {ll_millis};
dt = PyObject_Vectorcall(state->DatetimeMS, args, 1, NULL);
Py_DECREF(ll_millis);
return dt;
}
@ -401,7 +402,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o
int64_t min_millis_offset = 0;
int64_t max_millis_offset = 0;
if (options->tz_aware && options->tzinfo && options->tzinfo != Py_None) {
PyObject* utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->min_datetime, NULL);
PyObject* utcoffset_args[2] = {options->tzinfo, state->min_datetime};
PyObject* utcoffset = PyObject_VectorcallMethod(
state->_utcoffset_str, utcoffset_args, 2, NULL);
if (utcoffset == NULL) {
return 0;
}
@ -420,7 +423,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o
(PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000);
}
Py_DECREF(utcoffset);
utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->max_datetime, NULL);
utcoffset_args[1] = state->max_datetime;
utcoffset = PyObject_VectorcallMethod(
state->_utcoffset_str, utcoffset_args, 2, NULL);
if (utcoffset == NULL) {
return 0;
}
@ -481,7 +486,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o
/* convert to local time */
if (options->tzinfo != Py_None) {
PyObject* temp = PyObject_CallMethodObjArgs(value, state->_astimezone_str, options->tzinfo, NULL);
PyObject* astimezone_args[2] = {value, options->tzinfo};
PyObject* temp = PyObject_VectorcallMethod(
state->_astimezone_str, astimezone_args, 2, NULL);
Py_DECREF(value);
value = temp;
}
@ -688,7 +695,8 @@ static int _load_python_objects(PyObject* module) {
return 1;
}
compiled = PyObject_CallFunction(re_compile, "O", empty_string);
PyObject* compile_args[1] = {empty_string};
compiled = PyObject_Vectorcall(re_compile, compile_args, 1, NULL);
Py_DECREF(re_compile);
if (compiled == NULL) {
state->REType = NULL;
@ -711,13 +719,19 @@ static long _type_marker(PyObject* object, PyObject* _type_marker_str) {
PyObject* type_marker = NULL;
long type = 0;
if (PyObject_HasAttr(object, _type_marker_str)) {
type_marker = PyObject_GetAttr(object, _type_marker_str);
if (type_marker == NULL) {
#if PY_VERSION_HEX >= 0x030D0000
// 3.13
if (PyObject_GetOptionalAttr(object, _type_marker_str, &type_marker) == -1) {
return -1;
}
}
# else
if (PyObject_HasAttr(object, _type_marker_str)) {
type_marker = PyObject_GetAttr(object, _type_marker_str);
if (type_marker == NULL) {
return -1;
}
}
#endif
/*
* Python objects with broken __getattr__ implementations could return
* arbitrary types for a call to PyObject_GetAttrString. For example
@ -814,6 +828,7 @@ int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t
}
options->is_raw_bson = (101 == type_marker);
options->is_dict_class = (options->document_class == (PyObject*)&PyDict_Type);
options->options_obj = options_obj;
Py_INCREF(options->options_obj);
@ -1013,10 +1028,20 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
}
/*
* Use _type_marker attribute instead of PyObject_IsInstance for better perf.
*
* Skip _type_marker lookup for common built-in types
* that we know don't have a _type_marker attribute. This avoids the overhead
* of PyObject_HasAttr/PyObject_GetAttr calls for the most common cases.
*/
type = _type_marker(value, state->_type_marker_str);
if (type < 0) {
return 0;
if (PyUnicode_CheckExact(value) || PyLong_CheckExact(value) || PyFloat_CheckExact(value) ||
PyBool_Check(value) || PyDict_CheckExact(value) || PyList_CheckExact(value) ||
PyTuple_CheckExact(value) || PyBytes_CheckExact(value) || value == Py_None) {
type = 0;
} else {
type = _type_marker(value, state->_type_marker_str);
if (type < 0) {
return 0;
}
}
switch (type) {
@ -1227,7 +1252,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
case 100:
{
/* DBRef */
PyObject* as_doc = PyObject_CallMethodObjArgs(value, state->_as_doc_str, NULL);
PyObject* as_doc_args[1] = {value};
PyObject* as_doc = PyObject_VectorcallMethod(
state->_as_doc_str, as_doc_args, 1, NULL);
if (!as_doc) {
return 0;
}
@ -1383,7 +1410,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
return write_unicode(buffer, value);
} else if (PyDateTime_Check(value)) {
long long millis;
PyObject* utcoffset = PyObject_CallMethodObjArgs(value, state->_utcoffset_str , NULL);
PyObject* utcoffset_args[1] = {value};
PyObject* utcoffset = PyObject_VectorcallMethod(
state->_utcoffset_str, utcoffset_args, 1, NULL);
if (utcoffset == NULL)
return 0;
if (utcoffset != Py_None) {
@ -1422,7 +1451,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
if (!(uuid_rep_obj = PyLong_FromLong(options->uuid_rep))) {
return 0;
}
binary_value = PyObject_CallMethodObjArgs(state->Binary, state->_from_uuid_str, value, uuid_rep_obj, NULL);
PyObject* from_uuid_args[3] = {state->Binary, value, uuid_rep_obj};
binary_value = PyObject_VectorcallMethod(
state->_from_uuid_str, from_uuid_args, 3, NULL);
Py_DECREF(uuid_rep_obj);
if (binary_value == NULL) {
@ -1452,7 +1483,8 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
if (converter != NULL) {
/* Transform types that have a registered converter.
* A new reference is created upon transformation. */
new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
PyObject* converter_args[1] = {value};
new_value = PyObject_Vectorcall(converter, converter_args, 1, NULL);
if (new_value == NULL) {
return 0;
}
@ -1466,8 +1498,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
/* Try the fallback encoder if one is provided and we have not already
* attempted to use the fallback encoder. */
if (!in_fallback_call && options->type_registry.has_fallback_encoder) {
new_value = PyObject_CallFunctionObjArgs(
options->type_registry.fallback_encoder, value, NULL);
PyObject* fallback_args[1] = {value};
new_value = PyObject_Vectorcall(
options->type_registry.fallback_encoder, fallback_args, 1, NULL);
if (new_value == NULL) {
// propagate any exception raised by the callback
return 0;
@ -1645,11 +1678,11 @@ fail:
}
/* Update Invalid Document error message to include doc.
/* Update Invalid Document error to include doc as a property.
*/
void handle_invalid_doc_error(PyObject* dict) {
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
PyObject *msg = NULL, *dict_str = NULL, *new_msg = NULL;
PyObject *msg = NULL, *new_msg = NULL, *new_evalue = NULL;
PyErr_Fetch(&etype, &evalue, &etrace);
PyObject *InvalidDocument = _error("InvalidDocument");
if (InvalidDocument == NULL) {
@ -1657,30 +1690,29 @@ void handle_invalid_doc_error(PyObject* dict) {
}
if (evalue && PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
PyObject *msg = PyObject_Str(evalue);
msg = PyObject_Str(evalue);
if (msg) {
// Prepend doc to the existing message
PyObject *dict_str = PyObject_Str(dict);
if (dict_str == NULL) {
goto cleanup;
}
const char * dict_str_utf8 = PyUnicode_AsUTF8(dict_str);
if (dict_str_utf8 == NULL) {
goto cleanup;
}
const char * msg_utf8 = PyUnicode_AsUTF8(msg);
if (msg_utf8 == NULL) {
goto cleanup;
}
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", dict_str_utf8, msg_utf8);
new_msg = PyUnicode_FromFormat("Invalid document: %s", msg_utf8);
if (new_msg == NULL) {
goto cleanup;
}
// Add doc to the error instance as a property.
PyObject* exc_args[2] = {new_msg, dict};
new_evalue = PyObject_Vectorcall(InvalidDocument, exc_args, 2, NULL);
Py_DECREF(evalue);
Py_DECREF(etype);
etype = InvalidDocument;
InvalidDocument = NULL;
if (new_msg) {
evalue = new_msg;
if (new_evalue) {
evalue = new_evalue;
new_evalue = NULL;
} else {
evalue = msg;
msg = NULL;
}
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
@ -1689,7 +1721,7 @@ cleanup:
PyErr_Restore(etype, evalue, etrace);
Py_XDECREF(msg);
Py_XDECREF(InvalidDocument);
Py_XDECREF(dict_str);
Py_XDECREF(new_evalue);
Py_XDECREF(new_msg);
}
@ -1946,7 +1978,8 @@ static PyObject *_dbref_hook(PyObject* self, PyObject* value) {
PyMapping_DelItem(value, state->_dollar_db_str);
}
ret = PyObject_CallFunctionObjArgs(state->DBRef, ref, id, database, value, NULL);
PyObject* dbref_args[4] = {ref, id, database, value};
ret = PyObject_Vectorcall(state->DBRef, dbref_args, 4, NULL);
Py_DECREF(value);
} else {
ret = value;
@ -2162,7 +2195,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
goto uuiderror;
}
binary_value = PyObject_CallFunction(state->Binary, "(Oi)", data, subtype);
PyObject* subtype_obj = PyLong_FromLong(subtype);
if (!subtype_obj) {
goto uuiderror;
}
PyObject* binary_args[2] = {data, subtype_obj};
binary_value = PyObject_Vectorcall(state->Binary, binary_args, 2, NULL);
Py_DECREF(subtype_obj);
if (binary_value == NULL) {
goto uuiderror;
}
@ -2177,7 +2216,9 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
if (!uuid_rep_obj) {
goto uuiderror;
}
value = PyObject_CallMethodObjArgs(binary_value, state->_as_uuid_str, uuid_rep_obj, NULL);
PyObject* as_uuid_args[2] = {binary_value, uuid_rep_obj};
value = PyObject_VectorcallMethod(
state->_as_uuid_str, as_uuid_args, 2, NULL);
Py_DECREF(uuid_rep_obj);
}
@ -2196,7 +2237,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
Py_DECREF(data);
goto invalid;
}
value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL);
PyObject* binary_args[2] = {data, st};
value = PyObject_Vectorcall(state->Binary, binary_args, 2, NULL);
Py_DECREF(st);
Py_DECREF(data);
if (!value) {
@ -2217,7 +2259,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
if (max < 12) {
goto invalid;
}
value = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12);
PyObject* oid_bytes = PyBytes_FromStringAndSize(buffer + *position, 12);
if (!oid_bytes) {
goto invalid;
}
PyObject* oid_args[1] = {oid_bytes};
value = PyObject_Vectorcall(state->ObjectId, oid_args, 1, NULL);
Py_DECREF(oid_bytes);
*position += 12;
break;
}
@ -2296,7 +2344,14 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
*position += (unsigned)flags_length + 1;
value = PyObject_CallFunction(state->Regex, "Oi", pattern, flags);
PyObject* flags_obj = PyLong_FromLong(flags);
if (!flags_obj) {
Py_DECREF(pattern);
goto invalid;
}
PyObject* regex_args[2] = {pattern, flags_obj};
value = PyObject_Vectorcall(state->Regex, regex_args, 2, NULL);
Py_DECREF(flags_obj);
Py_DECREF(pattern);
break;
}
@ -2329,13 +2384,21 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
*position += coll_length;
id = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12);
PyObject* oid_bytes = PyBytes_FromStringAndSize(buffer + *position, 12);
if (!oid_bytes) {
Py_DECREF(collection);
goto invalid;
}
PyObject* oid_args[1] = {oid_bytes};
id = PyObject_Vectorcall(state->ObjectId, oid_args, 1, NULL);
Py_DECREF(oid_bytes);
if (!id) {
Py_DECREF(collection);
goto invalid;
}
*position += 12;
value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL);
PyObject* dbref_args[2] = {collection, id};
value = PyObject_Vectorcall(state->DBRef, dbref_args, 2, NULL);
Py_DECREF(collection);
Py_DECREF(id);
break;
@ -2365,7 +2428,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
goto invalid;
}
*position += value_length;
value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL);
PyObject* code_args[1] = {code};
value = PyObject_Vectorcall(state->Code, code_args, 1, NULL);
Py_DECREF(code);
break;
}
@ -2431,7 +2495,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
*position += scope_size;
value = PyObject_CallFunctionObjArgs(state->Code, code, scope, NULL);
PyObject* code_scope_args[2] = {code, scope};
value = PyObject_Vectorcall(state->Code, code_scope_args, 2, NULL);
Py_DECREF(code);
Py_DECREF(scope);
break;
@ -2461,7 +2526,19 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
memcpy(&time, buffer + *position + 4, 4);
inc = BSON_UINT32_FROM_LE(inc);
time = BSON_UINT32_FROM_LE(time);
value = PyObject_CallFunction(state->Timestamp, "II", time, inc);
PyObject* time_obj = PyLong_FromUnsignedLong(time);
if (!time_obj) {
goto invalid;
}
PyObject* inc_obj = PyLong_FromUnsignedLong(inc);
if (!inc_obj) {
Py_DECREF(time_obj);
goto invalid;
}
PyObject* ts_args[2] = {time_obj, inc_obj};
value = PyObject_Vectorcall(state->Timestamp, ts_args, 2, NULL);
Py_DECREF(time_obj);
Py_DECREF(inc_obj);
*position += 8;
break;
}
@ -2473,7 +2550,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
memcpy(&ll, buffer + *position, 8);
ll = (int64_t)BSON_UINT64_FROM_LE(ll);
value = PyObject_CallFunction(state->BSONInt64, "L", ll);
PyObject* ll_obj = PyLong_FromLongLong(ll);
if (!ll_obj) {
goto invalid;
}
PyObject* int64_args[1] = {ll_obj};
value = PyObject_Vectorcall(state->BSONInt64, int64_args, 1, NULL);
Py_DECREF(ll_obj);
*position += 8;
break;
}
@ -2486,19 +2569,21 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
if (!_bytes_obj) {
goto invalid;
}
value = PyObject_CallMethodObjArgs(state->Decimal128, state->_from_bid_str, _bytes_obj, NULL);
PyObject* dec128_args[2] = {state->Decimal128, _bytes_obj};
value = PyObject_VectorcallMethod(
state->_from_bid_str, dec128_args, 2, NULL);
Py_DECREF(_bytes_obj);
*position += 16;
break;
}
case 255:
{
value = PyObject_CallFunctionObjArgs(state->MinKey, NULL);
value = PyObject_Vectorcall(state->MinKey, NULL, 0, NULL);
break;
}
case 127:
{
value = PyObject_CallFunctionObjArgs(state->MaxKey, NULL);
value = PyObject_Vectorcall(state->MaxKey, NULL, 0, NULL);
break;
}
default:
@ -2550,7 +2635,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
converter = PyDict_GetItem(options->type_registry.decoder_map, value_type);
if (converter != NULL) {
PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
PyObject* converter_args[1] = {value};
PyObject* new_value = PyObject_Vectorcall(converter, converter_args, 1, NULL);
Py_DECREF(value_type);
Py_DECREF(value);
return new_value;
@ -2718,11 +2804,20 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string,
unsigned max,
const codec_options_t* options) {
unsigned position = 0;
PyObject* dict = PyObject_CallObject(options->document_class, NULL);
PyObject* dict;
int raw_array = 0;
/* Use PyDict_New() directly when document_class is dict.
* This avoids the overhead of PyObject_CallObject() for the common case. */
if (options->is_dict_class) {
dict = PyDict_New();
} else {
dict = PyObject_CallObject(options->document_class, NULL);
}
if (!dict) {
return NULL;
}
int raw_array = 0;
while (position < max) {
PyObject* name = NULL;
PyObject* value = NULL;
@ -2737,7 +2832,24 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string,
position = (unsigned)new_position;
}
PyObject_SetItem(dict, name, value);
/* Use PyDict_SetItem() when document_class is dict.
* PyDict_SetItem() is faster than PyObject_SetItem() because it
* avoids method lookup overhead. */
if (options->is_dict_class) {
if (PyDict_SetItem(dict, name, value) < 0) {
Py_DECREF(name);
Py_DECREF(value);
Py_DECREF(dict);
return NULL;
}
} else {
if (PyObject_SetItem(dict, name, value) < 0) {
Py_DECREF(name);
Py_DECREF(value);
Py_DECREF(dict);
return NULL;
}
}
Py_DECREF(name);
Py_DECREF(value);
}
@ -2749,9 +2861,14 @@ static PyObject* elements_to_dict(PyObject* self, const char* string,
const codec_options_t* options) {
PyObject* result;
if (options->is_raw_bson) {
return PyObject_CallFunction(
options->document_class, "y#O",
string, max, options->options_obj);
PyObject* bson_bytes = PyBytes_FromStringAndSize(string, max);
if (!bson_bytes) {
return NULL;
}
PyObject* raw_args[2] = {bson_bytes, options->options_obj};
result = PyObject_Vectorcall(options->document_class, raw_args, 2, NULL);
Py_DECREF(bson_bytes);
return result;
}
if (Py_EnterRecursiveCall(" while decoding a BSON document"))
return NULL;

View File

@ -72,6 +72,7 @@ typedef struct codec_options_t {
unsigned char datetime_conversion;
PyObject* options_obj;
unsigned char is_raw_bson;
unsigned char is_dict_class;
} codec_options_t;
/* C API functions */

View File

@ -65,6 +65,9 @@ if TYPE_CHECKING:
from array import array as _array
from mmap import mmap as _mmap
import numpy as np
import numpy.typing as npt
class UuidRepresentation:
UNSPECIFIED = 0
@ -234,13 +237,20 @@ class BinaryVector:
__slots__ = ("data", "dtype", "padding")
def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0):
def __init__(
self,
data: Union[Sequence[float | int], npt.NDArray[np.number]],
dtype: BinaryVectorDtype,
padding: int = 0,
):
"""
:param data: Sequence of numbers representing the mathematical vector.
:param dtype: The data type stored in binary
:param padding: The number of bits in the final byte that are to be ignored
when a vector element's size is less than a byte
and the length of the vector is not a multiple of 8.
(Padding is equivalent to a negative value of `count` in
`numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_)
"""
self.data = data
self.dtype = dtype
@ -298,7 +308,7 @@ class Binary(bytes):
def __new__(
cls: Type[Binary],
data: Union[memoryview, bytes, _mmap, _array[Any]],
data: Union[memoryview, bytes, bytearray, _mmap, _array[Any]],
subtype: int = BINARY_SUBTYPE,
) -> Binary:
if not isinstance(subtype, int):
@ -425,9 +435,19 @@ class Binary(bytes):
...
@classmethod
@overload
def from_vector(
cls: Type[Binary],
vector: Union[BinaryVector, list[int], list[float]],
vector: npt.NDArray[np.number],
dtype: BinaryVectorDtype,
padding: int = 0,
) -> Binary:
...
@classmethod
def from_vector(
cls: Type[Binary],
vector: Union[BinaryVector, list[int], list[float], npt.NDArray[np.number]],
dtype: Optional[BinaryVectorDtype] = None,
padding: Optional[int] = None,
) -> Binary:
@ -459,34 +479,72 @@ class Binary(bytes):
vector = vector.data # type: ignore
padding = 0 if padding is None else padding
if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8
format_str = "b"
if padding:
raise ValueError(f"padding does not apply to {dtype=}")
elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8
format_str = "B"
if 0 <= padding > 7:
raise ValueError(f"{padding=}. It must be in [0,1, ..7].")
if padding and not vector:
raise ValueError("Empty vector with non-zero padding.")
elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32
format_str = "f"
if padding:
raise ValueError(f"padding does not apply to {dtype=}")
else:
raise NotImplementedError("%s not yet supported" % dtype)
if not isinstance(dtype, BinaryVectorDtype):
raise TypeError(
"dtype must be a bson.BinaryVectorDtype of BinaryVectorDType.INT8, PACKED_BIT, FLOAT32"
)
metadata = struct.pack("<sB", dtype.value, padding)
data = struct.pack(f"<{len(vector)}{format_str}", *vector) # type: ignore
if isinstance(vector, list):
if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8
format_str = "b"
if padding:
raise ValueError(f"padding does not apply to {dtype=}")
elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8
format_str = "B"
if 0 <= padding > 7:
raise ValueError(f"{padding=}. It must be in [0,1, ..7].")
if padding and not vector:
raise ValueError("Empty vector with non-zero padding.")
elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32
format_str = "f"
if padding:
raise ValueError(f"padding does not apply to {dtype=}")
else:
raise NotImplementedError("%s not yet supported" % dtype)
data = struct.pack(f"<{len(vector)}{format_str}", *vector)
else: # vector is numpy array or incorrect type.
try:
import numpy as np
except ImportError as exc:
raise ImportError(
"Failed to create binary from vector. Check type. If numpy array, numpy must be installed."
) from exc
if not isinstance(vector, np.ndarray):
raise TypeError(
"Could not create Binary. Vector must be a BinaryVector, list[int], list[float] or numpy ndarray."
)
if vector.ndim != 1:
raise ValueError(
"from_numpy_vector only supports 1D arrays as it creates a single vector."
)
if dtype == BinaryVectorDtype.FLOAT32:
vector = vector.astype(np.dtype("float32"), copy=False)
elif dtype == BinaryVectorDtype.INT8:
if vector.min() >= -128 and vector.max() <= 127:
vector = vector.astype(np.dtype("int8"), copy=False)
else:
raise ValueError("Values found outside INT8 range.")
elif dtype == BinaryVectorDtype.PACKED_BIT:
if vector.min() >= 0 and vector.max() <= 127:
vector = vector.astype(np.dtype("uint8"), copy=False)
else:
raise ValueError("Values found outside UINT8 range.")
else:
raise NotImplementedError("%s not yet supported" % dtype)
data = vector.tobytes()
if padding and len(vector) and not (data[-1] & ((1 << padding) - 1)) == 0:
raise ValueError(
"Vector has a padding P, but bits in the final byte lower than P are non-zero. They must be zero."
)
return cls(metadata + data, subtype=VECTOR_SUBTYPE)
def as_vector(self) -> BinaryVector:
"""From the Binary, create a list of numbers, along with dtype and padding.
def as_vector(self, return_numpy: bool = False) -> BinaryVector:
"""From the Binary, create a list or 1-d numpy array of numbers, along with dtype and padding.
:param return_numpy: If True, BinaryVector.data will be a one-dimensional numpy array. By default, it is a list.
:return: BinaryVector
.. versionadded:: 4.10
@ -495,54 +553,84 @@ class Binary(bytes):
if self.subtype != VECTOR_SUBTYPE:
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
position = 0
dtype, padding = struct.unpack_from("<sB", self, position)
position += 2
dtype, padding = struct.unpack_from("<sB", self)
dtype = BinaryVectorDtype(dtype)
n_values = len(self) - position
offset = 2
n_bytes = len(self) - offset
if padding and dtype != BinaryVectorDtype.PACKED_BIT:
raise ValueError(
f"Corrupt data. Padding ({padding}) must be 0 for all but PACKED_BIT dtypes. ({dtype=})"
)
if dtype == BinaryVectorDtype.INT8:
dtype_format = "b"
format_string = f"<{n_values}{dtype_format}"
vector = list(struct.unpack_from(format_string, self, position))
return BinaryVector(vector, dtype, padding)
if not return_numpy:
if dtype == BinaryVectorDtype.INT8:
dtype_format = "b"
format_string = f"<{n_bytes}{dtype_format}"
vector = list(struct.unpack_from(format_string, self, offset))
return BinaryVector(vector, dtype, padding)
elif dtype == BinaryVectorDtype.FLOAT32:
n_bytes = len(self) - position
n_values = n_bytes // 4
if n_bytes % 4:
raise ValueError(
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
)
dtype_format = "f"
format_string = f"<{n_values}{dtype_format}"
vector = list(struct.unpack_from(format_string, self, position))
return BinaryVector(vector, dtype, padding)
elif dtype == BinaryVectorDtype.FLOAT32:
n_values = n_bytes // 4
if n_bytes % 4:
raise ValueError(
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
)
dtype_format = "f"
format_string = f"<{n_values}{dtype_format}"
vector = list(struct.unpack_from(format_string, self, offset))
return BinaryVector(vector, dtype, padding)
elif dtype == BinaryVectorDtype.PACKED_BIT:
# data packed as uint8
if padding and not n_values:
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
if padding > 7 or padding < 0:
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
dtype_format = "B"
format_string = f"<{n_values}{dtype_format}"
unpacked_uint8s = list(struct.unpack_from(format_string, self, position))
if padding and n_values and unpacked_uint8s[-1] & (1 << padding) - 1 != 0:
warnings.warn(
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
DeprecationWarning,
stacklevel=2,
)
return BinaryVector(unpacked_uint8s, dtype, padding)
elif dtype == BinaryVectorDtype.PACKED_BIT:
# data packed as uint8
if padding and not n_bytes:
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
if padding > 7 or padding < 0:
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
dtype_format = "B"
format_string = f"<{n_bytes}{dtype_format}"
unpacked_uint8s = list(struct.unpack_from(format_string, self, offset))
if padding and n_bytes and unpacked_uint8s[-1] & (1 << padding) - 1 != 0:
warnings.warn(
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
DeprecationWarning,
stacklevel=2,
)
return BinaryVector(unpacked_uint8s, dtype, padding)
else:
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
else:
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
else: # create a numpy array
try:
import numpy as np
except ImportError as exc:
raise ImportError(
"Converting binary to numpy.ndarray requires numpy to be installed."
) from exc
if dtype == BinaryVectorDtype.INT8:
data = np.frombuffer(self[offset:], dtype="int8")
elif dtype == BinaryVectorDtype.FLOAT32:
if n_bytes % 4:
raise ValueError(
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
)
data = np.frombuffer(self[offset:], dtype="float32")
elif dtype == BinaryVectorDtype.PACKED_BIT:
# data packed as uint8
if padding and not n_bytes:
raise ValueError("Corrupt data. Vector has a padding P, but no data.")
if padding > 7 or padding < 0:
raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.")
data = np.frombuffer(self[offset:], dtype="uint8")
if padding and np.unpackbits(data[-1])[-padding:].sum() > 0:
warnings.warn(
"Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.",
DeprecationWarning,
stacklevel=2,
)
else:
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
return BinaryVector(data, dtype, padding)
@property
def subtype(self) -> int:

View File

@ -273,9 +273,6 @@ if TYPE_CHECKING:
def _arguments_repr(self) -> str:
...
def _options_dict(self) -> dict[Any, Any]:
...
# NamedTuple API
@classmethod
def _make(cls, obj: Iterable[Any]) -> CodecOptions[_DocumentType]:
@ -466,19 +463,6 @@ else:
)
)
def _options_dict(self) -> dict[str, Any]:
"""Dictionary of the arguments used to create this object."""
# TODO: PYTHON-2442 use _asdict() instead
return {
"document_class": self.document_class,
"tz_aware": self.tz_aware,
"uuid_representation": self.uuid_representation,
"unicode_decode_error_handler": self.unicode_decode_error_handler,
"tzinfo": self.tzinfo,
"type_registry": self.type_registry,
"datetime_conversion": self.datetime_conversion,
}
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._arguments_repr()})"
@ -494,7 +478,7 @@ else:
.. versionadded:: 3.5
"""
opts = self._options_dict()
opts = self._asdict()
opts.update(kwargs)
return CodecOptions(**opts)

View File

@ -20,8 +20,11 @@ from __future__ import annotations
import decimal
import struct
from decimal import Decimal
from typing import Any, Sequence, Tuple, Type, Union
from bson.codec_options import TypeDecoder, TypeEncoder
_PACK_64 = struct.Struct("<Q").pack
_UNPACK_64 = struct.Struct("<Q").unpack
@ -58,6 +61,42 @@ _DEC128_CTX = decimal.Context(**_CTX_OPTIONS.copy()) # type: ignore
_VALUE_OPTIONS = Union[decimal.Decimal, float, str, Tuple[int, Sequence[int], int]]
class DecimalEncoder(TypeEncoder):
"""Converts Python :class:`decimal.Decimal` to BSON :class:`Decimal128`.
For example::
opts = CodecOptions(type_registry=TypeRegistry([DecimalEncoder()]))
bson.encode({"d": decimal.Decimal('1.0')}, codec_options=opts)
.. versionadded:: 4.15
"""
@property
def python_type(self) -> Type[Decimal]:
return Decimal
def transform_python(self, value: Any) -> Decimal128:
return Decimal128(value)
class DecimalDecoder(TypeDecoder):
"""Converts BSON :class:`Decimal128` to Python :class:`decimal.Decimal`.
For example::
opts = CodecOptions(type_registry=TypeRegistry([DecimalDecoder()]))
bson.decode(data, codec_options=opts)
.. versionadded:: 4.15
"""
@property
def bson_type(self) -> Type[Decimal128]:
return Decimal128
def transform_bson(self, value: Any) -> decimal.Decimal:
return value.to_decimal()
def create_decimal128_context() -> decimal.Context:
"""Returns an instance of :class:`decimal.Context` appropriate
for working with IEEE-754 128-bit decimal floating point values.

View File

@ -15,6 +15,8 @@
"""Exceptions raised by the BSON package."""
from __future__ import annotations
from typing import Any, Optional
class BSONError(Exception):
"""Base class for all BSON exceptions."""
@ -31,6 +33,17 @@ class InvalidStringData(BSONError):
class InvalidDocument(BSONError):
"""Raised when trying to create a BSON object from an invalid document."""
def __init__(self, message: str, document: Optional[Any] = None) -> None:
super().__init__(message)
self._document = document
@property
def document(self) -> Any:
"""The invalid document that caused the error.
..versionadded:: 4.16"""
return self._document
class InvalidId(BSONError):
"""Raised when trying to create an ObjectId from invalid data."""

View File

@ -382,19 +382,6 @@ class JSONOptions(_BASE_CLASS):
)
)
def _options_dict(self) -> dict[Any, Any]:
# TODO: PYTHON-2442 use _asdict() instead
options_dict = super()._options_dict()
options_dict.update(
{
"strict_number_long": self.strict_number_long,
"datetime_representation": self.datetime_representation,
"strict_uuid": self.strict_uuid,
"json_mode": self.json_mode,
}
)
return options_dict
def with_options(self, **kwargs: Any) -> JSONOptions:
"""
Make a copy of this JSONOptions, overriding some options::
@ -408,7 +395,7 @@ class JSONOptions(_BASE_CLASS):
.. versionadded:: 3.12
"""
opts = self._options_dict()
opts = self._asdict()
for opt in ("strict_number_long", "datetime_representation", "strict_uuid", "json_mode"):
opts[opt] = kwargs.get(opt, getattr(self, opt))
opts.update(kwargs)

View File

@ -15,7 +15,6 @@
"""Tools for working with MongoDB ObjectIds."""
from __future__ import annotations
import binascii
import datetime
import os
import struct
@ -98,11 +97,27 @@ class ObjectId:
objectid.rst>`_.
"""
if oid is None:
self.__generate()
# Generate a new value for this ObjectId.
with ObjectId._inc_lock:
inc = ObjectId._inc
ObjectId._inc = (inc + 1) % (_MAX_COUNTER_VALUE + 1)
# 4 bytes current time, 5 bytes random, 3 bytes inc.
self.__id = _PACK_INT_RANDOM(int(time.time()), ObjectId._random()) + _PACK_INT(inc)[1:4]
elif isinstance(oid, bytes) and len(oid) == 12:
self.__id = oid
elif isinstance(oid, str):
if len(oid) == 24:
try:
self.__id = bytes.fromhex(oid)
except (TypeError, ValueError):
_raise_invalid_id(oid)
else:
_raise_invalid_id(oid)
elif isinstance(oid, ObjectId):
self.__id = oid.binary
else:
self.__validate(oid)
raise TypeError(f"id must be an instance of (bytes, str, ObjectId), not {type(oid)}")
@classmethod
def from_datetime(cls: Type[ObjectId], generation_time: datetime.datetime) -> ObjectId:
@ -163,37 +178,6 @@ class ObjectId:
cls.__random = _random_bytes()
return cls.__random
def __generate(self) -> None:
"""Generate a new value for this ObjectId."""
with ObjectId._inc_lock:
inc = ObjectId._inc
ObjectId._inc = (inc + 1) % (_MAX_COUNTER_VALUE + 1)
# 4 bytes current time, 5 bytes random, 3 bytes inc.
self.__id = _PACK_INT_RANDOM(int(time.time()), ObjectId._random()) + _PACK_INT(inc)[1:4]
def __validate(self, oid: Any) -> None:
"""Validate and use the given id for this ObjectId.
Raises TypeError if id is not an instance of :class:`str`,
:class:`bytes`, or ObjectId. Raises InvalidId if it is not a
valid ObjectId.
:param oid: a valid ObjectId
"""
if isinstance(oid, ObjectId):
self.__id = oid.binary
elif isinstance(oid, str):
if len(oid) == 24:
try:
self.__id = bytes.fromhex(oid)
except (TypeError, ValueError):
_raise_invalid_id(oid)
else:
_raise_invalid_id(oid)
else:
raise TypeError(f"id must be an instance of (bytes, str, ObjectId), not {type(oid)}")
@property
def binary(self) -> bytes:
"""12-byte binary representation of this ObjectId."""
@ -234,7 +218,7 @@ class ObjectId:
self.__id = oid
def __str__(self) -> str:
return binascii.hexlify(self.__id).decode()
return self.__id.hex()
def __repr__(self) -> str:
return f"ObjectId('{self!s}')"

View File

@ -60,7 +60,9 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS as DEFAULT
def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument], raw_array: bool = False
bson_bytes: bytes | memoryview,
codec_options: CodecOptions[RawBSONDocument],
raw_array: bool = False,
) -> dict[str, Any]:
"""Inflates the top level fields of a BSON document.
@ -85,7 +87,9 @@ class RawBSONDocument(Mapping[str, Any]):
__codec_options: CodecOptions[RawBSONDocument]
def __init__(
self, bson_bytes: bytes, codec_options: Optional[CodecOptions[RawBSONDocument]] = None
self,
bson_bytes: bytes | memoryview,
codec_options: Optional[CodecOptions[RawBSONDocument]] = None,
) -> None:
"""Create a new :class:`RawBSONDocument`
@ -135,7 +139,7 @@ class RawBSONDocument(Mapping[str, Any]):
_get_object_size(bson_bytes, 0, len(bson_bytes))
@property
def raw(self) -> bytes:
def raw(self) -> bytes | memoryview:
"""The raw BSON bytes composing this document."""
return self.__raw
@ -153,7 +157,7 @@ class RawBSONDocument(Mapping[str, Any]):
@staticmethod
def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument]
bson_bytes: bytes | memoryview, codec_options: CodecOptions[RawBSONDocument]
) -> Mapping[str, Any]:
return _inflate_bson(bson_bytes, codec_options)
@ -180,7 +184,7 @@ class _RawArrayBSONDocument(RawBSONDocument):
@staticmethod
def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument]
bson_bytes: bytes | memoryview, codec_options: CodecOptions[RawBSONDocument]
) -> Mapping[str, Any]:
return _inflate_bson(bson_bytes, codec_options, raw_array=True)

View File

@ -143,7 +143,7 @@ class SON(Dict[_Key, _Value]):
del self[k]
return (k, v)
def update(self, other: Optional[Any] = None, **kwargs: _Value) -> None: # type: ignore[override]
def update(self, other: Optional[Any] = None, **kwargs: _Value) -> None:
# Make progressively weaker assumptions about "other"
if other is None:
pass

View File

@ -28,4 +28,4 @@ if TYPE_CHECKING:
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any])
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg]
_ReadableBuffer = Union[bytes, memoryview, bytearray, "mmap", "array"] # type: ignore[type-arg]

View File

@ -5,3 +5,4 @@
.. automodule:: pymongo.asynchronous.command_cursor
:synopsis: Tools for iterating over MongoDB command results
:members:
:inherited-members:

View File

@ -7,6 +7,8 @@
.. autoclass:: pymongo.asynchronous.cursor.AsyncCursor(collection, filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, batch_size=0, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None, session=None, allow_disk_use=None)
:members:
:inherited-members:
.. describe:: c[index]

View File

@ -4,3 +4,4 @@
.. automodule:: pymongo.command_cursor
:synopsis: Tools for iterating over MongoDB command results
:members:
:inherited-members:

View File

@ -17,6 +17,7 @@
.. autoclass:: pymongo.cursor.Cursor(collection, filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, batch_size=0, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None, session=None, allow_disk_use=None)
:members:
:inherited-members:
.. describe:: c[index]

View File

@ -1,6 +1,148 @@
Changelog
=========
Changes in Version 4.17.0 (2026/XX/XX)
--------------------------------------
PyMongo 4.17 brings a number of changes including:
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information.
Changes in Version 4.16.0 (2026/01/07)
--------------------------------------
PyMongo 4.16 brings a number of changes including:
- Removed invalid documents from :class:`bson.errors.InvalidDocument` error messages as
doing so may leak sensitive user data.
Instead, invalid documents are stored in :attr:`bson.errors.InvalidDocument.document`.
- PyMongo now requires ``dnspython>=2.6.1``, since ``dnspython`` 1.0 is no longer maintained.
The minimum version is ``2.6.1`` to account for `CVE-2023-29483 <https://www.cve.org/CVERecord?id=CVE-2023-29483>`_.
- Removed support for Eventlet.
Eventlet is actively being sunset by its maintainers and has compatibility issues with PyMongo's dnspython dependency.
- Use Zstandard support from the standard library for Python 3.14+, and use ``backports.zstd`` for older versions.
- Fixed return type annotation for ``find_one_and_*`` methods on :class:`~pymongo.asynchronous.collection.AsyncCollection`
and :class:`~pymongo.synchronous.collection.Collection` to include ``None``.
- Added support for NumPy 1D-arrays in :class:`bson.binary.BinaryVector`.
- Prevented :class:`~pymongo.encryption.ClientEncryption` from loading the crypt
shared library to fix "MongoCryptError: An existing crypt_shared library is
loaded by the application" unless the linked library search path is set.
Changes in Version 4.15.5 (2025/12/02)
--------------------------------------
Version 4.15.5 is a bug fix release.
- Fixed a bug that could cause ``AutoReconnect("connection pool paused")`` errors when cursors fetched more documents from the database after SDAM heartbeat failures.
Changes in Version 4.15.4 (2025/10/21)
--------------------------------------
Version 4.15.4 is a bug fix release.
- Relaxed the callback type of :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.with_transaction` to allow the broader Awaitable type rather than only Coroutine objects.
- Added the missing Python 3.14 trove classifier to the package metadata.
Issues Resolved
...............
See the `PyMongo 4.15.4 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.15.4 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=47237
Changes in Version 4.15.3 (2025/10/07)
--------------------------------------
Version 4.15.3 is a bug fix release.
- Fixed a memory leak when raising :class:`bson.errors.InvalidDocument` with C extensions.
- Fixed the return type of the :meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct`,
:meth:`~pymongo.synchronous.collection.Collection.distinct`, :meth:`pymongo.asynchronous.cursor.AsyncCursor.distinct`,
and :meth:`pymongo.asynchronous.cursor.AsyncCursor.distinct` methods.
Issues Resolved
...............
See the `PyMongo 4.15.3 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.15.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=47293
Changes in Version 4.15.2 (2025/10/01)
--------------------------------------
Version 4.15.2 is a bug fix release.
- Add wheels for Python 3.14 and 3.14t that were missing from 4.15.0 release. Drop the 3.13t wheel.
Issues Resolved
...............
See the `PyMongo 4.15.2 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.15.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=47186
Changes in Version 4.15.1 (2025/09/16)
--------------------------------------
Version 4.15.1 is a bug fix release.
- Fixed a bug in :meth:`~pymongo.synchronous.encryption.ClientEncryption.encrypt`
and :meth:`~pymongo.asynchronous.encryption.AsyncClientEncryption.encrypt`
that would cause a ``TypeError`` when using ``pymongocrypt<1.16`` by passing
an unsupported ``type_opts`` parameter even if Queryable Encryption text
queries beta was not used.
- Fixed a bug in ``AsyncMongoClient`` that caused a ``ServerSelectionTimeoutError``
when used with ``uvicorn``, ``FastAPI``, or ``uvloop``.
Issues Resolved
...............
See the `PyMongo 4.15.1 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.15.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=46486
Changes in Version 4.15.0 (2025/09/10)
--------------------------------------
PyMongo 4.15 brings a number of changes including:
- Added :class:`~pymongo.encryption_options.TextOpts`,
:attr:`~pymongo.encryption.Algorithm.TEXTPREVIEW`,
:attr:`~pymongo.encryption.QueryType.PREFIXPREVIEW`,
:attr:`~pymongo.encryption.QueryType.SUFFIXPREVIEW`,
:attr:`~pymongo.encryption.QueryType.SUBSTRINGPREVIEW`,
as part of the experimental Queryable Encryption text queries beta.
``pymongocrypt>=1.16`` is required for text query support.
- Added :class:`bson.decimal128.DecimalEncoder` and
:class:`bson.decimal128.DecimalDecoder`
to support encoding and decoding of BSON Decimal128 values to
decimal.Decimal values using the TypeRegistry API.
- Added support for Windows ``arm64`` wheels.
Changes in Version 4.14.1 (2025/08/19)
--------------------------------------
Version 4.14.1 is a bug fix release.
- Fixed a bug in ``MongoClient.append_metadata()`` and
``AsyncMongoClient.append_metadata()``
that allowed duplicate ``DriverInfo.name`` to be appended to the metadata.
Issues Resolved
...............
See the `PyMongo 4.14.1 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.14.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=45256
Changes in Version 4.14.0 (2025/08/06)
--------------------------------------
@ -34,6 +176,14 @@ PyMongo 4.14 brings a number of changes including:
- Changed :meth:`~pymongo.uri_parser.parse_uri`'s ``options`` return value to be
type ``dict`` instead of ``_CaseInsensitiveDictionary``.
Issues Resolved
...............
See the `PyMongo 4.14 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.14 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=43041
Changes in Version 4.13.2 (2025/06/17)
--------------------------------------

View File

@ -84,12 +84,16 @@ pygments_style = "sphinx"
# so those link results in a 404.
# wiki.centos.org has been flaky.
# sourceforge.net is giving a 403 error, but is still accessible from the browser.
# Links to release notes in jira give 401 error: unauthorized. PYTHON-5585
linkcheck_ignore = [
"https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-monitoring.md#requesting-an-immediate-check",
"https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback",
"https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.md",
"https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.md",
"https://github.com/mongodb/libmongocrypt/blob/master/bindings/python/README.rst#installing-from-source",
r"https://wiki.centos.org/[\w/]*",
r"https://sourceforge.net/",
r"https://jira\.mongodb\.org/secure/ReleaseNote\.jspa.*",
]
# Allow for flaky links.
@ -184,8 +188,8 @@ latex_documents = [
("index", "PyMongo.tex", "PyMongo Documentation", "Michael Dirolf", "manual"),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
# The name of an image file (relative to this directory) to place at the top
# of the title page.
# latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,

View File

@ -103,3 +103,8 @@ The following is a list of people who have contributed to
- Terry Patterson
- Romain Morotti
- Navjot Singh (navjots18)
- Jib Adegunloye (Jibola)
- Jeffrey A. Clark (aclark4life)
- Steven Silvester (blink1073)
- Noah Stapp (NoahStapp)
- Cal Jacobson (cj81499)

View File

@ -22,7 +22,7 @@ work with MongoDB from Python.
Getting Help
------------
If you're having trouble or have questions about PyMongo, ask your question on
our `MongoDB Community Forum <https://www.mongodb.com/community/forums/tag/python>`_.
one of the platforms listed on `Technical Support <https://www.mongodb.com/docs/manual/support/>`_.
You may also want to consider a
`commercial support subscription <https://support.mongodb.com/welcome>`_.
Once you get an answer, it'd be great if you could work it back into this
@ -37,7 +37,7 @@ project.
Feature Requests / Feedback
---------------------------
Use our `feedback engine <https://feedback.mongodb.com/forums/924286-drivers>`_
Use our `feedback engine <https://feedback.mongodb.com/?category=7548141816650747033>`_
to send us feature requests and general feedback about PyMongo.
Contributing

View File

@ -0,0 +1,42 @@
# Integration Tests
A set of tests that verify the usage of PyMongo with downstream packages or frameworks.
Each test uses [PEP 723 inline metadata](https://packaging.python.org/en/latest/specifications/inline-script-metadata/) and can be run using `pipx` or `uv`.
The `run.sh` convenience script can be used to run all of the files using `uv`.
Here is an example header for the script with the inline dependencies:
```python
# /// script
# dependencies = [
# "uvloop>=0.18"
# ]
# requires-python = ">=3.10"
# ///
```
Here is an example of using the test helper function to create a configured client for the test:
```python
import asyncio
import sys
from pathlib import Path
# Use pymongo from parent directory.
root = Path(__file__).parent.parent
sys.path.insert(0, str(root))
from test.asynchronous import async_simple_test_client # noqa: E402
async def main():
async with async_simple_test_client() as client:
result = await client.admin.command("ping")
assert result["ok"]
asyncio.run(main())
```

11
integration_tests/run.sh Executable file
View File

@ -0,0 +1,11 @@
#!/bin/bash
# Run all of the integration test files using `uv run`.
set -eu
for file in integration_tests/test_*.py ; do
echo "-----------------"
echo "Running $file..."
uv run $file
echo "Running $file...done."
echo "-----------------"
done

View File

@ -0,0 +1,27 @@
# /// script
# dependencies = [
# "uvloop>=0.18"
# ]
# requires-python = ">=3.10"
# ///
from __future__ import annotations
import sys
from pathlib import Path
import uvloop
# Use pymongo from parent directory.
root = Path(__file__).parent.parent
sys.path.insert(0, str(root))
from test.asynchronous import async_simple_test_client # noqa: E402
async def main():
async with async_simple_test_client() as client:
result = await client.admin.command("ping")
assert result["ok"]
uvloop.run(main())

View File

@ -1,10 +1,8 @@
# See https://just.systems/man/en/ for instructions
set shell := ["bash", "-c"]
# Do not modify the lock file when running justfile commands.
export UV_FROZEN := "1"
# Commonly used command segments.
typing_run := "uv run --group typing --extra aws --extra encryption --extra ocsp --extra snappy --extra test --extra zstd"
typing_run := "uv run --group typing --extra aws --extra encryption --with numpy --extra ocsp --extra snappy --extra test --extra zstd"
docs_run := "uv run --extra docs"
doc_build := "./doc/_build"
mypy_args := "--install-types --non-interactive"
@ -16,7 +14,7 @@ default:
[private]
resync:
@uv sync --quiet --frozen
@uv sync --quiet
install:
bash .evergreen/scripts/setup-dev-env.sh
@ -40,26 +38,33 @@ typing: && resync
[group('typing')]
typing-mypy: && resync
{{typing_run}} mypy {{mypy_args}} bson gridfs tools pymongo
{{typing_run}} mypy {{mypy_args}} --config-file mypy_test.ini test
{{typing_run}} mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py
{{typing_run}} python -m mypy {{mypy_args}} bson gridfs tools pymongo
{{typing_run}} python -m mypy {{mypy_args}} --config-file mypy_test.ini test
{{typing_run}} python -m mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py
[group('typing')]
typing-pyright: && resync
{{typing_run}} pyright test/test_typing.py test/test_typing_strict.py
{{typing_run}} pyright -p strict_pyrightconfig.json test/test_typing_strict.py
{{typing_run}} python -m pyright test/test_typing.py test/test_typing_strict.py
{{typing_run}} python -m pyright -p strict_pyrightconfig.json test/test_typing_strict.py
[group('lint')]
lint: && resync
uv run pre-commit run --all-files
lint *args="": && resync
uvx pre-commit run --all-files {{args}}
[group('lint')]
lint-manual: && resync
uv run pre-commit run --all-files --hook-stage manual
lint-manual *args="": && resync
uvx pre-commit run --all-files --hook-stage manual {{args}}
[group('test')]
test *args="-v --durations=5 --maxfail=10": && resync
uv run --extra test pytest {{args}}
#!/usr/bin/env bash
set -euo pipefail
uv run ${USE_ACTIVE_VENV:+--active} --extra test python -m pytest {{args}}
[group('test')]
test-numpy *args="": && resync
just setup-tests numpy {{args}}
just run-tests test/test_bson.py
[group('test')]
run-tests *args: && resync
@ -73,6 +78,29 @@ setup-tests *args="":
teardown-tests:
bash .evergreen/scripts/teardown-tests.sh
[group('test')]
integration-tests:
bash integration_tests/run.sh
[group('test')]
test-coverage *args="":
just setup-tests --cov
just run-tests {{args}}
[group('coverage')]
coverage-report:
uv tool run --with "coverage[toml]" coverage report
[group('coverage')]
coverage-html:
uv tool run --with "coverage[toml]" coverage html
@echo "Coverage report generated in htmlcov/index.html"
[group('coverage')]
coverage-xml:
uv tool run --with "coverage[toml]" coverage xml
@echo "Coverage report generated in coverage.xml"
[group('server')]
run-server *args="":
bash .evergreen/scripts/run-server.sh {{args}}

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import re
from typing import List, Tuple, Union
__version__ = "4.15.0.dev0"
__version__ = "4.17.0.dev0"
def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]:

View File

@ -50,7 +50,6 @@ class _AggregationCommand:
cursor_class: type[AsyncCommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
@ -92,7 +91,6 @@ class _AggregationCommand:
self._options["cursor"]["batchSize"] = self._batch_size
self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor
@ -197,7 +195,6 @@ class _AggregationCommand:
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)

View File

@ -236,7 +236,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
)
async def _run_aggregation_cmd(
self, session: Optional[AsyncClientSession], explicit_session: bool
self, session: Optional[AsyncClientSession]
) -> AsyncCommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this AsyncChangeStream and return
the corresponding AsyncCommandCursor.
@ -246,7 +246,6 @@ class AsyncChangeStream(Generic[_DocumentType]):
AsyncCommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
@ -258,10 +257,8 @@ class AsyncChangeStream(Generic[_DocumentType]):
)
async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None
)
async with self._client._tmp_session(self._session) as s:
return await self._run_aggregation_cmd(session=s)
async def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""

View File

@ -440,6 +440,8 @@ class _AsyncClientBulk:
) -> None:
"""Internal helper for processing the server reply command cursor."""
if result.get("cursor"):
if session:
session._leave_alive = True
coll = AsyncCollection(
database=AsyncDatabase(self.client, "admin"),
name="$cmd.bulkWrite",
@ -449,7 +451,6 @@ class _AsyncClientBulk:
result["cursor"],
conn.address,
session=session,
explicit_session=session is not None,
comment=self.comment,
)
await cmd_cursor._maybe_pin_connection(conn)
@ -562,9 +563,21 @@ class _AsyncClientBulk:
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = (
hasattr(error, "details")
and isinstance(error.details, dict)
and "errorLabels" in error.details
and isinstance(error.details["errorLabels"], list)
and "RetryableError" in error.details["errorLabels"]
)
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -141,12 +141,13 @@ import random
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Awaitable,
Callable,
Coroutine,
Mapping,
MutableMapping,
NoReturn,
@ -159,17 +160,18 @@ from bson.binary import Binary
from bson.int64 import Int64
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.cursor_base import _ConnectionManager
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.operations import _Op
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_type import SERVER_TYPE
@ -184,6 +186,28 @@ if TYPE_CHECKING:
_IS_SYNC = False
_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
class _AsyncBoundSessionContext:
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[AsyncClientSession]] = None
self._end_session = end_session
async def __aenter__(self) -> AsyncClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
await self._session.end_session()
class SessionOptions:
"""Options for a new :class:`AsyncClientSession`.
@ -407,6 +431,7 @@ class _Transaction:
self.recovery_token = None
self.attempt = 0
self.client = client
self.has_completed_command = False
def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -414,6 +439,9 @@ class _Transaction:
def starting(self) -> bool:
return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[AsyncConnection]:
if self.active() and self.conn_mgr:
@ -473,13 +501,24 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 1
_BACKOFF_INITIAL = 0.050 # 50ms initial backoff
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
else:
return NetworkTimeout(str(error))
_T = TypeVar("_T")
@ -518,6 +557,10 @@ class AsyncClientSession:
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
# Is this session attached to a cursor?
self._attached_to_cursor = False
# Should we leave the session alive when the cursor is closed?
self._leave_alive = False
async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
@ -540,7 +583,7 @@ class AsyncClientSession:
def _end_implicit_session(self) -> None:
# Implicit sessions can't be part of transactions or pinned connections
if self._server_session is not None:
if not self._leave_alive and self._server_session is not None:
self._client._return_server_session(self._server_session)
self._server_session = None
@ -548,6 +591,24 @@ class AsyncClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.
.. code-block:: python
async with client.start_session() as s:
async with s.bind():
# session=s is passed implicitly
await client.db.collection.insert_one({"x": 1})
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called
once the session is no longer used.
.. versionadded:: 4.17
"""
return _AsyncBoundSessionContext(self, end_session)
async def __aenter__(self) -> AsyncClientSession:
return self
@ -605,7 +666,7 @@ class AsyncClientSession:
async def with_transaction(
self,
callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]],
callback: Callable[[AsyncClientSession], Awaitable[_T]],
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
@ -705,10 +766,14 @@ class AsyncClientSession:
"""
start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX)
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
await asyncio.sleep(backoff)
retry += 1
await self.start_transaction(
@ -718,15 +783,16 @@ class AsyncClientSession:
ret = await callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
last_error = exc
if self.in_transaction:
await self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise
if not self.in_transaction:
@ -737,17 +803,16 @@ class AsyncClientSession:
try:
await self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
# Retry the entire transaction.
break
raise
@ -878,7 +943,7 @@ class AsyncClientSession:
return await self._finish_transaction(conn, command_name)
return await self._client._retry_internal(
func, self, None, retryable=True, operation=_Op.ABORT
func, self, None, retryable=True, operation=command_name
)
async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]:

View File

@ -20,7 +20,6 @@ from collections import abc
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Coroutine,
Generic,
@ -573,11 +572,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
await change_stream._initialize_cursor()
return change_stream
async def _conn_for_writes(
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
return await self._database.client._conn_for_writes(session, operation)
async def _command(
self,
conn: AsyncConnection,
@ -654,7 +648,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
async with await self._conn_for_writes(session, operation=_Op.CREATE) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
@ -671,6 +668,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.CREATE)
async def _create(
self,
options: MutableMapping[str, Any],
@ -2146,11 +2145,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
kwargs["comment"] = comment
pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}})
cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}}
if "hint" in kwargs and not isinstance(kwargs["hint"], str):
kwargs["hint"] = helpers_shared._index_document(kwargs["hint"])
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd.update(kwargs)
async def _cmd(
session: Optional[AsyncClientSession],
@ -2158,6 +2155,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
) -> int:
cmd: dict[str, Any] = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}}
cmd.update(kwargs)
result = await self._aggregate_one_result(
conn, read_preference, cmd, collation, session
)
@ -2243,7 +2242,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
async with await self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2272,7 +2274,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session),
session=session,
)
return names
return names
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_INDEXES
)
async def create_index(
self,
@ -2493,7 +2499,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
@ -2503,6 +2512,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
async def list_indexes(
self,
session: Optional[AsyncClientSession] = None,
@ -2552,7 +2563,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
)
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
explicit_session = session is not None
async def _cmd(
session: Optional[AsyncClientSession],
@ -2579,13 +2589,12 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cursor,
conn.address,
session=session,
explicit_session=explicit_session,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
async with self._database.client._tmp_session(session, False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._database.client._retryable_read(
_cmd, read_pref, s, operation=_Op.LIST_INDEXES
)
@ -2681,7 +2690,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
AsyncCommandCursor,
pipeline,
kwargs,
explicit_session=session is not None,
comment=comment,
user_fields={"cursor": {"firstBatch": 1}},
)
@ -2769,17 +2777,22 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs)
async with await self._conn_for_writes(
session, operation=_Op.CREATE_SEARCH_INDEXES
) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
resp = await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
return [index["name"] for index in resp["indexesCreated"]]
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
async def drop_search_index(
self,
name: str,
@ -2805,15 +2818,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
async def update_search_index(
self,
name: str,
@ -2841,15 +2860,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
async def options(
self,
session: Optional[AsyncClientSession] = None,
@ -2903,7 +2928,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
pipeline: _Pipeline,
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
session: Optional[AsyncClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -2915,7 +2939,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cursor_class,
pipeline,
kwargs,
explicit_session,
let,
user_fields={"cursor": {"firstBatch": 1}},
)
@ -2926,6 +2949,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session,
retryable=not cmd._performs_write,
operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
)
async def aggregate(
@ -3021,13 +3045,12 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._aggregate(
_CollectionAggregationCommand,
pipeline,
AsyncCommandCursor,
session=s,
explicit_session=session is not None,
let=let,
comment=comment,
**kwargs,
@ -3068,7 +3091,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
raise InvalidOperation("aggregate_raw_batches does not support auto encryption")
if comment is not None:
kwargs["comment"] = comment
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return cast(
AsyncRawBatchCursor[_DocumentType],
await self._aggregate(
@ -3076,7 +3099,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
pipeline,
AsyncRawBatchCommandCursor,
session=s,
explicit_session=session is not None,
**kwargs,
),
)
@ -3134,17 +3156,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
async with await self._conn_for_writes(session, operation=_Op.RENAME) as conn:
async with self._database.client._tmp_session(session) as s:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=s,
client=self._database.client,
)
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> MutableMapping[str, Any]:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=session,
client=client,
)
return await client._retryable_write(False, inner, session, _Op.RENAME)
async def distinct(
self,
@ -3154,7 +3180,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
comment: Optional[Any] = None,
hint: Optional[_IndexKeyHint] = None,
**kwargs: Any,
) -> list[str]:
) -> list[Any]:
"""Get a list of distinct values for `key` among all documents
in this collection.
@ -3198,19 +3224,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"""
if not isinstance(key, str):
raise TypeError(f"key must be an instance of str, not {type(key)}")
cmd = {"distinct": self._name, "key": key}
if filter is not None:
if "query" in kwargs:
raise ConfigurationError("can't pass both filter and query")
kwargs["query"] = filter
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
if hint is not None:
if not isinstance(hint, str):
hint = helpers_shared._index_document(hint)
cmd["hint"] = hint # type: ignore[assignment]
async def _cmd(
session: Optional[AsyncClientSession],
@ -3218,6 +3239,12 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
) -> list: # type: ignore[type-arg]
cmd = {"distinct": self._name, "key": key}
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
if hint is not None:
cmd["hint"] = hint # type: ignore[assignment]
return (
await self._command(
conn,
@ -3252,27 +3279,26 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}"
)
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd = {"findAndModify": self._name, "query": filter, "new": return_document}
if let is not None:
common.validate_is_mapping("let", let)
cmd["let"] = let
cmd.update(kwargs)
if projection is not None:
cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection")
if sort is not None:
cmd["sort"] = helpers_shared._index_document(sort)
if upsert is not None:
validate_boolean("upsert", upsert)
cmd["upsert"] = upsert
if hint is not None:
if not isinstance(hint, str):
hint = helpers_shared._index_document(hint)
write_concern = self._write_concern_for_cmd(cmd, session)
write_concern = self._write_concern_for_cmd(kwargs, session)
async def _find_and_modify_helper(
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
) -> Any:
cmd = {"findAndModify": self._name, "query": filter, "new": return_document}
if let is not None:
common.validate_is_mapping("let", let)
cmd["let"] = let
cmd.update(kwargs)
if projection is not None:
cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection")
if sort is not None:
cmd["sort"] = helpers_shared._index_document(sort)
if upsert is not None:
validate_boolean("upsert", upsert)
cmd["upsert"] = upsert
acknowledged = write_concern.acknowledged
if array_filters is not None:
if not acknowledged:
@ -3321,7 +3347,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> _DocumentType:
) -> Optional[_DocumentType]:
"""Finds a single document and deletes it, returning the document.
>>> await db.test.count_documents({'x': 1})
@ -3331,6 +3357,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
>>> await db.test.count_documents({'x': 1})
1
Returns ``None`` if no document matches the filter.
>>> await db.test.find_one_and_delete({'_exists': False})
If multiple documents match *filter*, a *sort* can be applied.
>>> async for doc in db.test.find({'x': 1}):
@ -3413,10 +3443,22 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> _DocumentType:
) -> Optional[_DocumentType]:
"""Finds a single document and replaces it, returning either the
original or the replaced document.
>>> await db.test.find_one({'x': 1})
{'_id': 0, 'x': 1}
>>> await db.test.find_one_and_replace({'x': 1}, {'y': 2})
{'_id': 0, 'x': 1}
>>> await db.test.find_one({'x': 1})
>>> await db.test.find_one({'y': 2})
{'_id': 0, 'y': 2}
Returns ``None`` if no document matches the filter.
>>> await db.test.find_one_and_replace({'_exists': False}, {'x': 1})
The :meth:`find_one_and_replace` method differs from
:meth:`find_one_and_update` by replacing the document matched by
*filter*, rather than modifying the existing document.
@ -3521,13 +3563,17 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> _DocumentType:
) -> Optional[_DocumentType]:
"""Finds a single document and updates it, returning either the
original or the updated document.
>>> await db.test.find_one({'_id': 665})
{'_id': 665, 'done': False, 'count': 25}
>>> await db.test.find_one_and_update(
... {'_id': 665}, {'$inc': {'count': 1}, '$set': {'done': True}})
{'_id': 665, 'done': False, 'count': 25}}
{'_id': 665, 'done': False, 'count': 25}
>>> await db.test.find_one({'_id': 665})
{'_id': 665, 'done': True, 'count': 26}
Returns ``None`` if no document matches the filter.

View File

@ -20,7 +20,6 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generic,
Mapping,
NoReturn,
Optional,
@ -29,17 +28,10 @@ from typing import (
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo import _csot
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.cursor_base import _AsyncCursorBase, _ConnectionManager
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
_RawBatchGetMore,
)
from pymongo.message import _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentOut, _DocumentType
@ -51,7 +43,7 @@ if TYPE_CHECKING:
_IS_SYNC = False
class AsyncCommandCursor(Generic[_DocumentType]):
class AsyncCommandCursor(_AsyncCursorBase[_DocumentType]):
"""An asynchronous cursor / iterator over command cursors."""
_getmore_class = _GetMore
@ -64,7 +56,6 @@ class AsyncCommandCursor(Generic[_DocumentType]):
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
@ -80,7 +71,8 @@ class AsyncCommandCursor(Generic[_DocumentType]):
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
if self._session is not None:
self._session._attached_to_cursor = True
self._killed = self._id == 0
self._comment = comment
if self._killed:
@ -98,8 +90,8 @@ class AsyncCommandCursor(Generic[_DocumentType]):
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
)
def __del__(self) -> None:
self._die_no_lock()
def _get_namespace(self) -> str:
return self._ns
def batch_size(self, batch_size: int) -> AsyncCommandCursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
@ -161,92 +153,12 @@ class AsyncCommandCursor(Generic[_DocumentType]):
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration`. Best to use a for loop::
async for doc in collection.aggregate(pipeline):
print(doc)
.. note:: :attr:`alive` can be True while iterating a cursor from
a failed server. In this case :attr:`alive` will return False after
:meth:`next` fails to retrieve the next batch of results from the
server.
"""
return bool(len(self._data) or (not self._killed))
@property
def cursor_id(self) -> int:
"""Returns the id of the cursor."""
return self._id
@property
def address(self) -> Optional[_Address]:
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
"""
return self._address
@property
def session(self) -> Optional[AsyncClientSession]:
"""The cursor's :class:`~pymongo.asynchronous.client_session.AsyncClientSession`, or None.
.. versionadded:: 3.6
"""
if self._explicit_session:
return self._session
return None
def _prepare_to_die(self) -> tuple[int, Optional[_CursorAddress]]:
already_killed = self._killed
self._killed = True
if self._id and not already_killed:
cursor_id = self._id
assert self._address is not None
address = _CursorAddress(self._address, self._ns)
else:
# Skip killCursors.
cursor_id = 0
address = None
return cursor_id, address
def _die_no_lock(self) -> None:
"""Closes this cursor without acquiring a lock."""
cursor_id, address = self._prepare_to_die()
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
async def _die_lock(self) -> None:
"""Closes this cursor."""
cursor_id, address = self._prepare_to_die()
await self._collection.database.client._cleanup_cursor_lock(
cursor_id,
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
def _end_session(self) -> None:
if self._session and not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session._end_implicit_session()
self._session = None
async def close(self) -> None:
"""Explicitly close / kill this cursor."""
await self._die_lock()
async def _send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
client = self._collection.database.client
@ -328,6 +240,9 @@ class AsyncCommandCursor(Generic[_DocumentType]):
def __aiter__(self) -> AsyncIterator[_DocumentType]:
return self
async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
return self
async def next(self) -> _DocumentType:
"""Advance the cursor."""
# Block until a document is returnable.
@ -383,41 +298,6 @@ class AsyncCommandCursor(Generic[_DocumentType]):
"""
return await self._try_next(get_more_allowed=True)
async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
@_csot.apply
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, so read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
_getmore_class = _RawBatchGetMore
@ -430,7 +310,6 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
@ -449,7 +328,6 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)

View File

@ -21,7 +21,6 @@ from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
List,
Mapping,
@ -36,7 +35,8 @@ from typing import (
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
from bson.code import Code
from bson.son import SON
from pymongo import _csot, helpers_shared
from pymongo import helpers_shared
from pymongo.asynchronous.cursor_base import _AsyncCursorBase, _ConnectionManager
from pymongo.asynchronous.helpers import anext
from pymongo.collation import validate_collation_or_none
from pymongo.common import (
@ -45,9 +45,7 @@ from pymongo.common import (
)
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _async_create_lock
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
@ -65,31 +63,12 @@ if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.read_preferences import _ServerMode
_IS_SYNC = False
class _ConnectionManager:
"""Used with exhaust cursors to ensure the connection is returned."""
def __init__(self, conn: AsyncConnection, more_to_come: bool):
self.conn: Optional[AsyncConnection] = conn
self.more_to_come = more_to_come
self._lock = _async_create_lock()
def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come
async def close(self) -> None:
"""Return this instance's connection to the connection pool."""
if self.conn:
await self.conn.unpin()
self.conn = None
class AsyncCursor(Generic[_DocumentType]):
class AsyncCursor(_AsyncCursorBase[_DocumentType]):
_query_class = _Query
_getmore_class = _GetMore
@ -138,10 +117,9 @@ class AsyncCursor(Generic[_DocumentType]):
if session:
self._session = session
self._explicit_session = True
self._session._attached_to_cursor = True
else:
self._session = None
self._explicit_session = False
spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
@ -150,7 +128,7 @@ class AsyncCursor(Generic[_DocumentType]):
if not isinstance(limit, int):
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
if no_cursor_timeout and self._session and self._session._implicit:
warnings.warn(
"use an explicit session with no_cursor_timeout=True "
"otherwise the cursor may still timeout after "
@ -267,8 +245,8 @@ class AsyncCursor(Generic[_DocumentType]):
"""The number of documents retrieved so far."""
return self._retrieved
def __del__(self) -> None:
self._die_no_lock()
def _get_namespace(self) -> str:
return f"{self._dbname}.{self._collname}"
def clone(self) -> AsyncCursor[_DocumentType]:
"""Get a clone of this cursor.
@ -283,7 +261,7 @@ class AsyncCursor(Generic[_DocumentType]):
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
if self._session and not self._session._implicit:
base = self._clone_base(self._session)
else:
base = self._clone_base(None)
@ -900,55 +878,6 @@ class AsyncCursor(Generic[_DocumentType]):
self._read_preference = self._collection._read_preference_for(self.session)
return self._read_preference
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
This is mostly useful with `tailable cursors
<https://www.mongodb.com/docs/manual/core/tailable-cursors/>`_
since they will stop iterating even though they *may* return more
results in the future.
With regular cursors, simply use an asynchronous for loop instead of :attr:`alive`::
async for doc in collection.find():
print(doc)
.. note:: Even if :attr:`alive` is True, :meth:`next` can raise
:exc:`StopIteration`. :attr:`alive` can also be True while iterating
a cursor from a failed server. In this case :attr:`alive` will
return False after :meth:`next` fails to retrieve the next batch
of results from the server.
"""
return bool(len(self._data) or (not self._killed))
@property
def cursor_id(self) -> Optional[int]:
"""Returns the id of the cursor
.. versionadded:: 2.2
"""
return self._id
@property
def address(self) -> Optional[tuple[str, Any]]:
"""The (host, port) of the server used, or None.
.. versionchanged:: 3.0
Renamed from "conn_id".
"""
return self._address
@property
def session(self) -> Optional[AsyncClientSession]:
"""The cursor's :class:`~pymongo.asynchronous.client_session.AsyncClientSession`, or None.
.. versionadded:: 3.6
"""
if self._explicit_session:
return self._session
return None
def __copy__(self) -> AsyncCursor[_DocumentType]:
"""Support function for `copy.copy()`.
@ -1009,62 +938,10 @@ class AsyncCursor(Generic[_DocumentType]):
else:
if not isinstance(key, RE_TYPE):
key = copy.deepcopy(key, memo) # noqa: PLW2901
y[key] = value
y[key] = value # type:ignore[index]
return y
def _prepare_to_die(self, already_killed: bool) -> tuple[int, Optional[_CursorAddress]]:
self._killed = True
if self._id and not already_killed:
cursor_id = self._id
assert self._address is not None
address = _CursorAddress(self._address, f"{self._dbname}.{self._collname}")
else:
# Skip killCursors.
cursor_id = 0
address = None
return cursor_id, address
def _die_no_lock(self) -> None:
"""Closes this cursor without acquiring a lock."""
try:
already_killed = self._killed
except AttributeError:
# ___init__ did not run to completion (or at all).
return
cursor_id, address = self._prepare_to_die(already_killed)
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
async def _die_lock(self) -> None:
"""Closes this cursor."""
try:
already_killed = self._killed
except AttributeError:
# ___init__ did not run to completion (or at all).
return
cursor_id, address = self._prepare_to_die(already_killed)
await self._collection.database.client._cleanup_cursor_lock(
cursor_id,
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
async def close(self) -> None:
"""Explicitly close / kill this cursor."""
await self._die_lock()
async def distinct(self, key: str) -> list[str]:
async def distinct(self, key: str) -> list[Any]:
"""Get a list of distinct values for `key` among all documents
in the result set of this query.
@ -1296,40 +1173,8 @@ class AsyncCursor(Generic[_DocumentType]):
async def __aenter__(self) -> AsyncCursor[_DocumentType]:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
@_csot.apply
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]): # type: ignore[type-arg]
class AsyncRawBatchCursor(AsyncCursor[_DocumentType]):
"""An asynchronous cursor / iterator over raw batches of BSON data from a query result."""
_query_class = _RawBatchQuery

View File

@ -0,0 +1,122 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Asynchronous cursor base extending the shared agnostic cursor base."""
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from pymongo import _csot
from pymongo.cursor_shared import _AgnosticCursorBase
from pymongo.lock import _async_create_lock
from pymongo.typings import _DocumentType
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.pool import AsyncConnection
_IS_SYNC = False
class _ConnectionManager:
"""Used with exhaust cursors to ensure the connection is returned."""
def __init__(self, conn: AsyncConnection, more_to_come: bool):
self.conn: Optional[AsyncConnection] = conn
self.more_to_come = more_to_come
self._lock = _async_create_lock()
def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come
async def close(self) -> None:
"""Return this instance's connection to the connection pool."""
if self.conn:
await self.conn.unpin()
self.conn = None
class _AsyncCursorBase(_AgnosticCursorBase[_DocumentType]):
"""Asynchronous cursor base class."""
@property
def session(self) -> Optional[AsyncClientSession]:
"""The cursor's :class:`~pymongo.asynchronous.client_session.AsyncClientSession`, or None.
.. versionadded:: 3.6
"""
if self._session and not self._session._implicit:
return self._session
return None
@abstractmethod
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
...
async def _die_lock(self) -> None:
"""Closes this cursor."""
try:
already_killed = self._killed
except AttributeError:
# ___init__ did not run to completion (or at all).
return
cursor_id, address = self._prepare_to_die(already_killed)
await self._collection.database.client._cleanup_cursor_lock(
cursor_id,
address,
self._sock_mgr,
self._session,
)
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
async def close(self) -> None:
"""Explicitly close / kill this cursor."""
await self._die_lock()
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
@_csot.apply
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res

View File

@ -614,6 +614,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
common.validate_is_mapping("clusteredIndex", clustered_index)
async with self._client._tmp_session(session) as s:
if s and not s.in_transaction:
s._leave_alive = True
# Skip this check in a transaction where listCollections is not
# supported.
if (
@ -622,6 +624,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
and name in await self._list_collection_names(filter={"name": name}, session=s)
):
raise CollectionInvalid("collection %s already exists" % name)
if s:
s._leave_alive = False
coll = AsyncCollection(
self,
name,
@ -697,18 +701,17 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.9
.. _aggregation pipeline:
https://mongodb.com/docs/manual/reference/operator/aggregation-pipeline
https://www.mongodb.com/docs/manual/core/aggregation-pipeline/
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self.client._tmp_session(session, close=False) as s:
async with self.client._tmp_session(session) as s:
cmd = _DatabaseAggregationCommand(
self,
AsyncCommandCursor,
pipeline,
kwargs,
session is not None,
user_fields={"cursor": {"firstBatch": 1}},
)
return await self.client._retryable_read(
@ -932,14 +935,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, session, operation=command_name
) as (
connection,
read_preference,
):
async def inner(
session: Optional[AsyncClientSession],
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return await self._command(
connection,
conn,
command,
value,
check,
@ -950,6 +954,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
**kwargs,
)
return await self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply
@_retry_overload
async def cursor_command(
@ -1016,19 +1024,19 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
else:
command_name = next(iter(command))
async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
read_preference = (
tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, tmp_session, command_name
) as (
conn,
read_preference,
):
async def inner(
session: Optional[AsyncClientSession],
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
response = await self._command(
conn,
command,
@ -1037,7 +1045,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
None,
read_preference,
opts,
session=tmp_session,
session=session,
**kwargs,
)
coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1047,8 +1055,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
response["cursor"],
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
explicit_session=session is not None,
session=session,
comment=comment,
)
await cmd_cursor._maybe_pin_connection(conn)
@ -1056,6 +1063,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
else:
raise InvalidOperation("Command does not return a cursor.")
return await self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
async def _retryable_read_command(
self,
command: Union[str, MutableMapping[str, Any]],
@ -1094,7 +1105,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
)
cmd = {"listCollections": 1, "cursor": {}}
cmd.update(kwargs)
async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
cursor = (
await self._command(conn, cmd, read_preference=read_preference, session=tmp_session)
)["cursor"]
@ -1103,7 +1114,6 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
cursor,
conn.address,
session=tmp_session,
explicit_session=session is not None,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
@ -1258,9 +1268,11 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
async with await self._client._conn_for_writes(session, operation=_Op.DROP) as connection:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> dict[str, Any]:
return await self._command(
connection,
conn,
command,
allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session),
@ -1268,6 +1280,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
session=session,
)
return await self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply
@_retry_overload
async def drop_collection(

View File

@ -64,10 +64,14 @@ from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncBaseConnection
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
from pymongo.encryption_options import (
AutoEncryptionOpts,
RangeOpts,
TextOpts,
check_min_pymongocrypt,
)
from pymongo.errors import (
ConfigurationError,
EncryptedCollectionError,
@ -77,11 +81,11 @@ from pymongo.errors import (
ServerSelectionTimeoutError,
)
from pymongo.helpers_shared import _get_timeout_details
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
from pymongo.network_layer import async_socket_sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_configured_protocol_interface,
_async_configured_socket,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
@ -94,8 +98,10 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
_IS_SYNC = False
_HTTPS_PORT = 443
@ -110,10 +116,9 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
return AsyncBaseConnection(interface, opts)
return await _async_configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@ -198,11 +203,19 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
try:
conn = await _connect_kms(address, opts)
try:
await async_sendall(conn.conn.get_conn, message)
await async_socket_sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = await async_receive_kms(conn, kms_context.bytes_needed)
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data: memoryview | bytes
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
async_receive_data_socket,
)
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
@ -221,7 +234,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
)
finally:
await conn.close_conn(None)
conn.close()
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
@ -264,7 +277,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
args.extend(self.opts._mongocryptd_spawn_args)
_spawn_daemon(args)
async def mark_command(self, database: str, cmd: bytes) -> bytes:
async def mark_command(self, database: str, cmd: bytes) -> bytes | memoryview:
"""Mark a command for encryption.
:param database: The database on which to run this command.
@ -291,7 +304,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
)
return res.raw
async def fetch_keys(self, filter: bytes) -> AsyncGenerator[bytes, None]:
async def fetch_keys(self, filter: bytes) -> AsyncGenerator[bytes | memoryview, None]:
"""Yields one or more keys from the key vault.
:param filter: The filter to pass to find.
@ -463,7 +476,7 @@ class _Encrypter:
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
async def decrypt(self, response: bytes) -> Optional[bytes]:
async def decrypt(self, response: bytes | memoryview) -> Optional[bytes]:
"""Decrypt a MongoDB command response.
:param response: A MongoDB command response as BSON.
@ -516,6 +529,11 @@ class Algorithm(str, enum.Enum):
.. versionadded:: 4.4
"""
TEXTPREVIEW = "TextPreview"
"""**BETA** - TextPreview.
.. versionadded:: 4.15
"""
class QueryType(str, enum.Enum):
@ -541,6 +559,24 @@ class QueryType(str, enum.Enum):
.. versionadded:: 4.4
"""
PREFIXPREVIEW = "prefixPreview"
"""**BETA** - Used to encrypt a value for a prefixPreview query.
.. versionadded:: 4.15
"""
SUFFIXPREVIEW = "suffixPreview"
"""**BETA** - Used to encrypt a value for a suffixPreview query.
.. versionadded:: 4.15
"""
SUBSTRINGPREVIEW = "substringPreview"
"""**BETA** - Used to encrypt a value for a substringPreview query.
.. versionadded:: 4.15
"""
def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
# For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms.
@ -644,6 +680,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
"python -m pip install --upgrade 'pymongo[encryption]'"
)
check_min_pymongocrypt()
if not isinstance(codec_options, CodecOptions):
raise TypeError(
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
@ -679,7 +717,10 @@ class AsyncClientEncryption(Generic[_DocumentType]):
self._encryption = AsyncExplicitEncrypter(
self._io_callbacks,
_create_mongocrypt_options(
kms_providers=kms_providers, schema_map=None, key_expiration_ms=key_expiration_ms
kms_providers=kms_providers,
schema_map=None,
key_expiration_ms=key_expiration_ms,
bypass_encryption=True, # Don't load crypt_shared
),
)
# Use the same key vault collection as the callback.
@ -876,6 +917,7 @@ class AsyncClientEncryption(Generic[_DocumentType]):
contention_factor: Optional[int] = None,
range_opts: Optional[RangeOpts] = None,
is_expression: bool = False,
text_opts: Optional[TextOpts] = None,
) -> Any:
self._check_closed()
if isinstance(key_id, uuid.UUID):
@ -895,6 +937,12 @@ class AsyncClientEncryption(Generic[_DocumentType]):
range_opts.document,
codec_options=self._codec_options,
)
text_opts_bytes = None
if text_opts:
text_opts_bytes = encode(
text_opts.document,
codec_options=self._codec_options,
)
with _wrap_encryption_errors():
encrypted_doc = await self._encryption.encrypt(
value=doc,
@ -905,6 +953,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
contention_factor=contention_factor,
range_opts=range_opts_bytes,
is_expression=is_expression,
# For compatibility with pymongocrypt < 1.16:
**{"text_opts": text_opts_bytes} if text_opts_bytes else {},
)
return decode(encrypted_doc)["v"]
@ -917,6 +967,7 @@ class AsyncClientEncryption(Generic[_DocumentType]):
query_type: Optional[str] = None,
contention_factor: Optional[int] = None,
range_opts: Optional[RangeOpts] = None,
text_opts: Optional[TextOpts] = None,
) -> Binary:
"""Encrypt a BSON value with a given key and algorithm.
@ -937,9 +988,14 @@ class AsyncClientEncryption(Generic[_DocumentType]):
used.
:param range_opts: Index options for `range` queries. See
:class:`RangeOpts` for some valid options.
:param text_opts: Index options for `textPreview` queries. See
:class:`TextOpts` for some valid options.
:return: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6.
.. versionchanged:: 4.9
Added the `text_opts` parameter.
.. versionchanged:: 4.9
Added the `range_opts` parameter.
@ -960,6 +1016,7 @@ class AsyncClientEncryption(Generic[_DocumentType]):
contention_factor=contention_factor,
range_opts=range_opts,
is_expression=False,
text_opts=text_opts,
),
)

View File

@ -32,7 +32,6 @@ from typing import (
from pymongo import _csot
from pymongo.errors import (
OperationFailure,
PyMongoError,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
from pymongo.lock import _async_create_lock
@ -80,7 +79,6 @@ def _handle_reauth(func: F) -> F:
_MAX_RETRIES = 5
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1
@ -102,7 +100,6 @@ class _TokenBucket:
):
self.lock = _async_create_lock()
self.capacity = capacity
# DRIVERS-3240 will determine how full the bucket should start.
self.tokens = capacity
self.return_rate = return_rate
@ -124,7 +121,7 @@ class _TokenBucket:
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter.
Retry attempts are limited by a token bucket to prevent overwhelming the server during
When adaptive retries are enabled, retry attempts are limited by a token bucket to prevent overwhelming the server during
a prolonged outage or high load.
"""
@ -134,15 +131,18 @@ class _RetryPolicy:
attempts: int = _MAX_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
adaptive_retry: bool = False,
):
self.token_bucket = token_bucket
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
self.adaptive_retry = adaptive_retry
async def record_success(self, retry: bool) -> None:
"""Record a successful operation."""
await self.token_bucket.deposit(retry)
if self.adaptive_retry:
await self.token_bucket.deposit(retry)
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given ."""
@ -159,41 +159,13 @@ class _RetryPolicy:
return False
# Check token bucket last since we only want to consume a token if we actually retry.
if not await self.token_bucket.consume():
if self.adaptive_retry and not await self.token_bucket.consume():
# DRIVERS-3246 Improve diagnostics when this case happens.
# We could add info to the exception and log.
return False
return True
def _retry_overload(func: F) -> F:
@functools.wraps(func)
async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
retry_policy = self._retry_policy
attempt = 0
while True:
try:
res = await func(self, *args, **kwargs)
await retry_policy.record_success(retry=attempt > 0)
return res
except PyMongoError as exc:
if not exc.has_error_label("RetryableError"):
raise
attempt += 1
delay = 0
if exc.has_error_label("SystemOverloadedError"):
delay = retry_policy.backoff(attempt)
if not await retry_policy.should_retry(attempt, delay):
raise
# Implement exponential backoff on retry.
if delay:
await asyncio.sleep(delay)
continue
return cast(F, inner)
async def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[
@ -202,7 +174,7 @@ async def _getaddrinfo(
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
]
]:
if not _IS_SYNC:

View File

@ -66,10 +66,9 @@ from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import (
_retry_overload,
_RetryPolicy,
_TokenBucket,
)
@ -145,7 +144,7 @@ if TYPE_CHECKING:
from bson.objectid import ObjectId
from pymongo.asynchronous.bulk import _AsyncBulk
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.cursor_base import _ConnectionManager
from pymongo.asynchronous.encryption import _Encrypter
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
@ -428,8 +427,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
with the server. Currently supported options are "snappy", "zlib"
and "zstd". Support for snappy requires the
`python-snappy <https://pypi.org/project/python-snappy/>`_ package.
zlib support requires the Python standard library zlib module. zstd
requires the `zstandard <https://pypi.org/project/zstandard/>`_
zlib support requires the Python standard library zlib module. For
Python before 3.14 zstd requires the `backports.zstd <https://pypi.org/project/backports.zstd/>`_
package. By default no compression is used. Compression support
must also be enabled on the server. MongoDB 3.6+ supports snappy
and zlib compression. MongoDB 4.2+ adds support for zstd.
@ -616,8 +615,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details.
| **Adaptive retry options:**
| (If not enabled explicitly, adaptive retries will not be enabled.)
- `adaptive_retries`: (boolean) Whether the adaptive retry mechanism is enabled for this client.
If enabled, server overload errors will use a token-bucket based system to mitigate further overload.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``adaptive_retries`` URI and keyword argument.
.. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument.
@ -779,7 +788,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._timeout: float | None = None
self._topology_settings: TopologySettings = None # type: ignore[assignment]
self._event_listeners: _EventListeners | None = None
self._retry_policy = _RetryPolicy(_TokenBucket())
# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
@ -886,11 +894,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
self._retry_policy = _RetryPolicy(
_TokenBucket(), adaptive_retry=self._options.adaptive_retries
)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1415,7 +1428,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
def _ensure_session(
self, session: Optional[AsyncClientSession] = None
) -> Optional[AsyncClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
if session:
return session
@ -1997,6 +2011,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Internal retryable helper for all client transactions.
@ -2008,6 +2024,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func()
"""
@ -2022,6 +2040,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run()
async def _retryable_read(
@ -2033,6 +2053,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None,
retryable: bool = True,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Execute an operation with consecutive retries if possible
@ -2048,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
"""
# Ensure that the client supports retrying on reads and there is no session in
@ -2055,17 +2079,20 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
retryable = bool(
retryable and self.options.retry_reads and not (session and session.in_transaction)
)
return await self._retry_internal(
func,
session,
None,
operation,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
)
async with self._tmp_session(session) as s:
return await self._retry_internal(
func,
s,
None,
operation,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
)
async def _retryable_write(
self,
@ -2098,7 +2125,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[AsyncClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from __del__ without locking.
@ -2113,7 +2139,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# The cursor will be closed later in a different session.
if cursor_id or conn_mgr:
self._close_cursor_soon(cursor_id, address, conn_mgr)
if session and not explicit_session:
if session and session._implicit and not session._leave_alive:
session._end_implicit_session()
async def _cleanup_cursor_lock(
@ -2122,7 +2148,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[AsyncClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from cursor.close() using a lock.
@ -2134,7 +2159,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: The _CursorAddress.
:param conn_mgr: The _ConnectionManager for the pinned connection or None.
:param session: The cursor's session.
:param explicit_session: True if the session was passed explicitly.
"""
if cursor_id:
if conn_mgr and conn_mgr.more_to_come:
@ -2147,7 +2171,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
if conn_mgr:
await conn_mgr.close()
if session and not explicit_session:
if session and session._implicit and not session._leave_alive:
session._end_implicit_session()
async def _close_cursor_now(
@ -2228,7 +2252,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None)
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2273,14 +2297,17 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.asynccontextmanager
async def _tmp_session(
self, session: Optional[client_session.AsyncClientSession], close: bool = True
self, session: Optional[client_session.AsyncClientSession]
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
@ -2298,7 +2325,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
raise
finally:
# Call end_session when we exit this scope.
if close:
if not s._attached_to_cursor:
await s.end_session()
else:
yield None
@ -2310,6 +2337,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if session is not None:
session._process_response(reply)
def _get_bound_session(self) -> Optional[AsyncClientSession]:
bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
)
else:
return None
async def server_info(
self, session: Optional[client_session.AsyncClientSession] = None
) -> dict[str, Any]:
@ -2405,7 +2444,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
return [doc["name"] async for doc in res]
@_csot.apply
@_retry_overload
async def drop_database(
self,
name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]],
@ -2448,15 +2486,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}"
)
async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
await self[name]._command(
conn,
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
await self[name].command(
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
@_csot.apply
async def bulk_write(
@ -2740,6 +2776,8 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
):
self._last_error: Optional[Exception] = None
self._retrying = False
@ -2762,6 +2800,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation
self._operation_id = operation_id
self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
async def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@ -2783,6 +2823,11 @@ class _ClientConnectionRetryable(Generic[T]):
try:
res = await self._read() if self._is_read else await self._write()
await self._retry_policy.record_success(self._attempt_number > 0)
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
@ -2797,24 +2842,48 @@ class _ClientConnectionRetryable(Generic[T]):
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None)
always_retryable = exc.has_error_label("RetryableError")
overloaded = exc.has_error_label("SystemOverloadedError")
if not always_retryable and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
always_retryable = exc.has_error_label("RetryableError") and overloaded
if (
not self._client.options.retry_reads
or not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
):
raise
self._retrying = True
self._last_error = exc
self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else:
raise
@ -2825,9 +2894,14 @@ class _ClientConnectionRetryable(Generic[T]):
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
always_retryable = exc_to_check.has_error_label("RetryableError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
if not self._retryable and not always_retryable:
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise
if retryable_write_label or always_retryable:
assert self._session
@ -2848,20 +2922,30 @@ class _ClientConnectionRetryable(Generic[T]):
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
):
self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if always_retryable:
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
if not await self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
if overloaded:
await asyncio.sleep(delay)
await asyncio.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""

View File

@ -19,6 +19,8 @@ import collections
import contextlib
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -37,7 +39,7 @@ from typing import (
from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _backoff, _handle_reauth
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.network import command
from pymongo.common import (
MAX_BSON_SIZE,
@ -52,10 +54,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -104,38 +108,78 @@ if TYPE_CHECKING:
from pymongo.typings import _Address, _CollationIn
from pymongo.write_concern import WriteConcern
try:
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
def _set_non_inheritable_non_atomic(fd: int) -> None:
"""Set the close-on-exec flag on the given file descriptor."""
flags = fcntl(fd, F_GETFD)
fcntl(fd, F_SETFD, flags | FD_CLOEXEC)
except ImportError:
# Windows, various platforms we don't claim to support
# (Jython, IronPython, ..), systems that don't provide
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
"""Dummy function for platforms that don't provide fcntl."""
_IS_SYNC = False
class AsyncBaseConnection:
"""A base connection object for server and kms connections."""
class AsyncConnection:
"""Store a connection with some metadata.
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
:param conn: a raw connection object
:param pool: a Pool instance
:param address: the server's (host, port)
:param id: the id of this socket in it's pool
:param is_sdam: SDAM connections do not call hello on creation
"""
def __init__(
self,
conn: AsyncNetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
is_sdam: bool,
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
self.socket_checker: SocketChecker = SocketChecker()
self.cancel_context: _CancellationContext = _CancellationContext()
self.is_sdam = False
self.address = address
self.id = id
self.is_sdam = is_sdam
self.closed = False
self.last_timeout: float | None = None
self.more_to_come = False
self.opts = opts
self.max_wire_version = -1
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size = MAX_BSON_SIZE
self.max_message_size = MAX_MESSAGE_SIZE
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
self.compression_settings = pool.opts._compression_settings
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.socket_checker: SocketChecker = SocketChecker()
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
self.cancel_context: _CancellationContext = _CancellationContext()
self.opts = pool.opts
self.more_to_come: bool = False
# For load balancer support.
self.service_id: Optional[ObjectId] = None
self.server_connection_id: Optional[int] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
self.pinned_cursor = False
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self._client_id = pool._client_id
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
@ -164,111 +208,17 @@ class AsyncBaseConnection:
formatted = format_timeout_details(timeout_details)
# CSOT: raise an error without running the command since we know it will time out.
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
if self.max_wire_version != -1:
raise ExecutionTimeout(
errmsg,
50,
{"ok": 0, "errmsg": errmsg, "code": 50},
self.max_wire_version,
)
else:
raise TimeoutError(errmsg)
raise ExecutionTimeout(
errmsg,
50,
{"ok": 0, "errmsg": errmsg, "code": 50},
self.max_wire_version,
)
if cmd is not None:
cmd["maxTimeMS"] = int(max_time_ms * 1000)
self.set_conn_timeout(timeout)
return timeout
async def close_conn(self, reason: Optional[str]) -> None:
"""Close this connection with a reason."""
if self.closed:
return
await self._close_conn()
async def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
self.closed = True
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
await self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
class AsyncConnection(AsyncBaseConnection):
"""Store a connection with some metadata.
:param conn: a raw connection object
:param pool: a Pool instance
:param address: the server's (host, port)
:param id: the id of this socket in it's pool
:param is_sdam: SDAM connections do not call hello on creation
"""
def __init__(
self,
conn: AsyncNetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
is_sdam: bool,
):
super().__init__(conn, pool.opts)
self.pool_ref = weakref.ref(pool)
self.address: tuple[str, int] = address
self.id: int = id
self.is_sdam = is_sdam
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size: int = MAX_BSON_SIZE
self.max_message_size: int = MAX_MESSAGE_SIZE
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos: bool = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
self.compression_settings = pool.opts._compression_settings
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
# For load balancer support.
self.service_id: Optional[ObjectId] = None
self.server_connection_id: Optional[int] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
self.pinned_cursor = False
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self._client_id = pool._client_id
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
def pin_txn(self) -> None:
self.pinned_txn = True
assert not self.pinned_cursor
@ -304,6 +254,7 @@ class AsyncConnection(AsyncBaseConnection):
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
cmd["backpressure"] = True
if performing_handshake:
self.performed_handshake = True
cmd["client"] = self.opts.metadata
@ -612,6 +563,26 @@ class AsyncConnection(AsyncBaseConnection):
error=reason,
)
async def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
self.closed = True
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
await self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
command: MutableMapping[str, Any],
@ -647,7 +618,7 @@ class AsyncConnection(AsyncBaseConnection):
# signals and throws KeyboardInterrupt into the current frame on the
# main thread.
#
# But in Gevent and Eventlet, the polling mechanism (epoll, kqueue,
# But in Gevent, the polling mechanism (epoll, kqueue,
# ..) is called in Python code, which experiences the signal as a
# KeyboardInterrupt from the start, rather than as an initial
# socket.error, so we catch that, close the socket, and reraise it.
@ -725,8 +696,6 @@ class PoolState:
CLOSED = 3
# Do *not* explicitly inherit from object or Jython won't call __del__
# https://bugs.jython.org/issue1057
class Pool:
def __init__(
self,
@ -789,8 +758,8 @@ class Pool:
# Also used for: clearing the wait queue
self._max_connecting_cond = _async_create_condition(self.lock)
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
self._backoff = 0
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_created(
@ -846,8 +815,6 @@ class Pool:
async with self.size_cond:
if self.closed:
return
# Clear the backoff state.
self._backoff = 0
if self.opts.pause_enabled and pause and not self.opts.load_balanced:
old_state, self.state = self.state, PoolState.PAUSED
self.gen.inc(service_id)
@ -930,11 +897,6 @@ class Pool:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.STALE)
@property
def max_connecting(self) -> int:
"""The current max connecting limit for the pool."""
return 1 if self._backoff else self.opts.max_connecting
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
Pool.
@ -1001,7 +963,7 @@ class Pool:
async with self._max_connecting_cond:
# If maxConnecting connections are already being created
# by this pool then try again later instead of waiting.
if self._pending >= self.max_connecting:
if self._pending >= self._max_connecting:
return
self._pending += 1
incremented = True
@ -1029,29 +991,20 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def _handle_connection_error(self, error: BaseException, phase: str, conn_id: int) -> None:
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for an AutoReconnect error raised from a ConnectionResetError with
# errno == errno.ECONNRESET or raised from an OSError that we've created due to
# a closed connection.
# If found, set backoff and add error labels.
if self.is_sdam or type(error) != AutoReconnect:
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
self._backoff += 1
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
# Log the pool backoff message.
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
message=_ConnectionStatusMessage.POOL_BACKOFF,
clientId=self._client_id,
serverHost=self.address[0],
serverPort=self.address[1],
driverConnectionId=conn_id,
reason=_verbose_connection_error_reason(ConnectionClosedReason.POOL_BACKOFF),
error=ConnectionClosedReason.POOL_BACKOFF,
)
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
"""Connect to Mongo and return a new AsyncConnection.
@ -1082,17 +1035,8 @@ class Pool:
driverConnectionId=conn_id,
)
# Apply backoff if applicable.
if self._backoff:
await asyncio.sleep(_backoff(self._backoff))
# Pass a context to determine if we successfully create a configured socket.
context = dict(has_created_socket=False)
try:
networking_interface = await _configured_protocol_interface(
self.address, self.opts, context=context
)
networking_interface = await _configured_protocol_interface(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
@ -1113,8 +1057,7 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
if context["has_created_socket"]:
self._handle_connection_error(error, "handshake", conn_id)
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
@ -1126,9 +1069,11 @@ class Pool:
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
@ -1138,15 +1083,14 @@ class Pool:
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
self._handle_connection_error(e, "hello", conn_id)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)
# Clear the backoff state.
self._backoff = 0
return conn
@contextlib.asynccontextmanager
@ -1323,12 +1267,12 @@ class Pool:
# to be checked back into the pool.
async with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self.max_connecting):
while not (self.conns or self._pending < self._max_connecting):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.conns or self._pending < self.max_connecting:
if self.conns or self._pending < self._max_connecting:
self._max_connecting_cond.notify()
emitted_event = True
self._raise_wait_queue_timeout(checkout_started_time)
@ -1469,7 +1413,7 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the
pool, or we are in backoff mode, to keep performance reasonable -
pool to keep performance reasonable -
we can't avoid AutoReconnects completely anyway.
"""
idle_time_seconds = conn.idle_time_seconds()
@ -1482,8 +1426,6 @@ class Pool:
return True
check_interval_seconds = self._check_interval_seconds
if self._backoff:
check_interval_seconds = 0
if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
):

View File

@ -50,20 +50,11 @@ async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
if _IS_SYNC:
from dns import resolver
if hasattr(resolver, "resolve"):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)
return resolver.resolve(*args, **kwargs)
else:
from dns import asyncresolver
if hasattr(asyncresolver, "resolve"):
# dnspython >= 2
return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
raise ConfigurationError(
"Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections."
)
return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
_INVALID_HOST_MSG = (
@ -107,7 +98,7 @@ class _SrvResolver:
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc)) from None
raise ConfigurationError(str(exc)) from exc
if len(results) > 1:
raise ConfigurationError("Only one TXT record is supported")
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined]
@ -122,7 +113,7 @@ class _SrvResolver:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc)) from None
raise ConfigurationError(str(exc)) from exc
return results
async def _get_srv_response_and_hosts(
@ -145,8 +136,8 @@ class _SrvResolver:
)
try:
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
except Exception as exc:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:

View File

@ -111,7 +111,7 @@ class Topology:
self._publish_tp = self._listeners is not None and self._listeners.enabled_for_topology
# Create events queue if there are publishers.
self._events = None
self._events: queue.Queue[Any] | None = None
self.__events_executor: Any = None
if self._publish_server or self._publish_tp:
@ -126,6 +126,7 @@ class Topology:
if self._publish_tp:
assert self._events is not None
assert self._listeners is not None
self._events.put((self._listeners.publish_topology_opened, (self._topology_id,)))
self._settings = topology_settings
topology_description = TopologyDescription(
@ -143,6 +144,7 @@ class Topology:
)
if self._publish_tp:
assert self._events is not None
assert self._listeners is not None
self._events.put(
(
self._listeners.publish_topology_description_changed,
@ -161,6 +163,7 @@ class Topology:
for seed in topology_settings.seeds:
if self._publish_server:
assert self._events is not None
assert self._listeners is not None
self._events.put((self._listeners.publish_server_opened, (seed, self._topology_id)))
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
@ -265,6 +268,7 @@ class Topology:
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.
@ -292,7 +296,12 @@ class Topology:
async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)
return [
@ -306,6 +315,7 @@ class Topology:
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
@ -324,7 +334,12 @@ class Topology:
)
server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)
while not server_descriptions:
@ -385,9 +400,13 @@ class Topology:
operation_id: Optional[int] = None,
) -> Server:
servers = await self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
@ -491,6 +510,7 @@ class Topology:
suppress_event = sd_old == server_description
if self._publish_server and not suppress_event:
assert self._events is not None
assert self._listeners is not None
self._events.put(
(
self._listeners.publish_server_description_changed,
@ -503,6 +523,7 @@ class Topology:
if self._publish_tp and not suppress_event:
assert self._events is not None
assert self._listeners is not None
self._events.put(
(
self._listeners.publish_topology_description_changed,
@ -570,6 +591,7 @@ class Topology:
if self._publish_tp:
assert self._events is not None
assert self._listeners is not None
self._events.put(
(
self._listeners.publish_topology_description_changed,
@ -723,6 +745,7 @@ class Topology:
# Publish only after releasing the lock.
if self._publish_tp:
assert self._events is not None
assert self._listeners is not None
self._description = TopologyDescription(
TOPOLOGY_TYPE.Unknown,
{},
@ -890,8 +913,8 @@ class Topology:
# Clear the pool.
await server.reset(service_id)
elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError) or error.has_error_label(
"SystemOverloadedError"
if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return
# "Client MUST replace the server's description with type Unknown
@ -1114,16 +1137,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]
def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates
filtered = [server for server in candidates if server not in deprioritized_servers]
# If not possible to pick a prioritized server, return the original list
return filtered or candidates

View File

@ -159,6 +159,7 @@ def _build_credentials_tuple(
"localhost",
"127.0.0.1",
"::1",
"*.mongo.com",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
if properties.get("ALLOWED_HOSTS", None) is not None and human_callback is None:

View File

@ -235,6 +235,11 @@ class ClientOptions:
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
self.__adaptive_retries = (
options.get("adaptive_retries", common.ADAPTIVE_RETRIES)
if "adaptive_retries" in options
else options.get("adaptiveretries", common.ADAPTIVE_RETRIES)
)
@property
def _options(self) -> Mapping[str, Any]:
@ -346,3 +351,11 @@ class ClientOptions:
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode
@property
def adaptive_retries(self) -> bool:
"""The configured adaptiveRetries option.
.. versionadded:: 4.XX
"""
return self.__adaptive_retries

View File

@ -20,6 +20,7 @@ import datetime
import warnings
from collections import OrderedDict, abc
from difflib import get_close_matches
from importlib.metadata import requires, version
from typing import (
TYPE_CHECKING,
Any,
@ -139,6 +140,9 @@ SRV_SERVICE_NAME = "mongodb"
# Default value for serverMonitoringMode
SERVER_MONITORING_MODE = "auto" # poll/stream/auto
# Default value for adaptiveRetries
ADAPTIVE_RETRIES = False
# Auth mechanism properties that must raise an error instead of warning if they invalidate.
_MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"]
@ -737,6 +741,7 @@ URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = {
"srvmaxhosts": validate_non_negative_integer,
"timeoutms": validate_timeoutms,
"servermonitoringmode": validate_server_monitoring_mode,
"adaptiveretries": validate_boolean_or_string,
}
# Dictionary where keys are the names of URI options specific to pymongo,
@ -770,6 +775,7 @@ KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = {
"server_selector": validate_is_callable_or_none,
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list,
"adaptive_retries": validate_boolean_or_string,
}
# Dictionary where keys are any URI option name, and values are the
@ -1092,3 +1098,91 @@ def has_c() -> bool:
return True
except ImportError:
return False
class Version(tuple[int, ...]):
"""A class that can be used to compare version strings."""
def __new__(cls, *version: int) -> Version:
padded_version = cls._padded(version, 4)
return super().__new__(cls, tuple(padded_version))
@classmethod
def _padded(cls, iter: Any, length: int, padding: int = 0) -> list[int]:
as_list = list(iter)
if len(as_list) < length:
for _ in range(length - len(as_list)):
as_list.append(padding)
return as_list
@classmethod
def from_string(cls, version_string: str) -> Version:
mod = 0
bump_patch_level = False
if version_string.endswith("+"):
version_string = version_string[0:-1]
mod = 1
elif version_string.endswith("-pre-"):
version_string = version_string[0:-5]
mod = -1
elif version_string.endswith("-"):
version_string = version_string[0:-1]
mod = -1
# Deal with .devX substrings
if ".dev" in version_string:
version_string = version_string[0 : version_string.find(".dev")]
mod = -1
# Deal with '-rcX' substrings
if "-rc" in version_string:
version_string = version_string[0 : version_string.find("-rc")]
mod = -1
# Deal with git describe generated substrings
elif "-" in version_string:
version_string = version_string[0 : version_string.find("-")]
mod = -1
bump_patch_level = True
version = [int(part) for part in version_string.split(".")]
version = cls._padded(version, 3)
# Make from_string and from_version_array agree. For example:
# MongoDB Enterprise > db.runCommand('buildInfo').versionArray
# [ 3, 2, 1, -100 ]
# MongoDB Enterprise > db.runCommand('buildInfo').version
# 3.2.0-97-g1ef94fe
if bump_patch_level:
version[-1] += 1
version.append(mod)
return Version(*version)
@classmethod
def from_version_array(cls, version_array: Any) -> Version:
version = list(version_array)
if version[-1] < 0:
version[-1] = -1
version = cls._padded(version, 3)
return Version(*version)
def at_least(self, *other_version: Any) -> bool:
return self >= Version(*other_version)
def __str__(self) -> str:
return ".".join(map(str, self))
def check_for_min_version(package_name: str) -> tuple[str, str, bool]:
"""Test whether an installed package is of the desired version."""
package_version_str = version(package_name)
package_version = Version.from_string(package_version_str)
# Dependency is expected to be in one of the forms:
# "pymongocrypt<2.0.0,>=1.13.0; extra == 'encryption'"
# 'dnspython<3.0.0,>=1.16.0'
#
requirements = requires("pymongo")
assert requirements is not None
requirement = [i for i in requirements if i.startswith(package_name)][0] # noqa: RUF015
if ";" in requirement:
requirement = requirement.split(";")[0]
required_version = requirement[requirement.find(">=") + 2 :]
is_valid = package_version >= Version.from_string(required_version)
return package_version_str, required_version, is_valid

View File

@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations
import sys
import warnings
from typing import Any, Iterable, Optional, Union
@ -44,7 +45,10 @@ def _have_zlib() -> bool:
def _have_zstd() -> bool:
try:
import zstandard # noqa: F401
if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd # noqa: F401
return True
except ImportError:
@ -79,11 +83,18 @@ def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[s
)
elif compressor == "zstd" and not _have_zstd():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"You must install the zstandard module for zstandard support.",
stacklevel=2,
)
if sys.version_info >= (3, 14):
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"The compression.zstd module is not available.",
stacklevel=2,
)
else:
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"You must install the backports.zstd module for zstandard support.",
stacklevel=2,
)
return compressors
@ -144,15 +155,15 @@ class ZstdContext:
@staticmethod
def compress(data: bytes) -> bytes:
# ZstdCompressor is not thread safe.
# TODO: Use a pool?
if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd
import zstandard
return zstandard.ZstdCompressor().compress(data)
return zstd.compress(data)
def decompress(data: bytes, compressor_id: int) -> bytes:
def decompress(data: bytes | memoryview, compressor_id: int) -> bytes:
if compressor_id == SnappyContext.compressor_id:
# python-snappy doesn't support the buffer interface.
# https://github.com/andrix/python-snappy/issues/65
@ -166,10 +177,11 @@ def decompress(data: bytes, compressor_id: int) -> bytes:
return zlib.decompress(data)
elif compressor_id == ZstdContext.compressor_id:
# ZstdDecompressor is not thread safe.
# TODO: Use a pool?
import zstandard
if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd
return zstandard.ZstdDecompressor().decompress(data)
return zstd.decompress(data)
else:
raise ValueError("Unknown compressorId %d" % (compressor_id,))

View File

@ -16,7 +16,104 @@
"""Constants and types shared across all cursor classes."""
from __future__ import annotations
from typing import Any, Mapping, Sequence, Tuple, Union
from abc import ABC, abstractmethod
from typing import Any, Generic, Mapping, Optional, Sequence, Tuple, Union
from pymongo.message import _CursorAddress
from pymongo.typings import _Address, _DocumentType
class _AgnosticCursorBase(Generic[_DocumentType], ABC):
"""
Shared IO-agnostic cursor base used by both async and sync cursor classes.
All IO-specific behavior is implemented in subclasses.
"""
# These are all typed more accurately in subclasses.
_collection: Any
_id: Optional[int]
_data: Any
_address: Optional[_Address]
_sock_mgr: Any
_session: Optional[Any]
_killed: bool
@abstractmethod
def _get_namespace(self) -> str:
"""Return the full namespace (dbname.collname) for this cursor."""
...
def __del__(self) -> None:
self._die_no_lock()
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
This is mostly useful with `tailable cursors
<https://www.mongodb.com/docs/manual/core/tailable-cursors/>`_
since they will stop iterating even though they *may* return more
results in the future.
With regular cursors, simply use an asynchronous for loop instead of :attr:`alive`::
async for doc in collection.find():
print(doc)
.. note:: Even if :attr:`alive` is True, :meth:`next` can raise
:exc:`StopIteration`. :attr:`alive` can also be True while iterating
a cursor from a failed server. In this case :attr:`alive` will
return False after :meth:`next` fails to retrieve the next batch
of results from the server.
"""
return bool(len(self._data) or (not self._killed))
@property
def cursor_id(self) -> Optional[int]:
"""Returns the id of the cursor.
.. versionadded:: 2.2
"""
return self._id
@property
def address(self) -> Optional[_Address]:
"""The (host, port) of the server used, or None.
.. versionchanged:: 3.0
Renamed from "conn_id".
"""
return self._address
def _prepare_to_die(self, already_killed: bool) -> tuple[int, Optional[_CursorAddress]]:
self._killed = True
if self._id and not already_killed:
cursor_id = self._id
assert self._address is not None
address = _CursorAddress(self._address, self._get_namespace())
else:
# Skip killCursors.
cursor_id = 0
address = None
return cursor_id, address
def _die_no_lock(self) -> None:
"""Closes this cursor without acquiring a lock."""
try:
already_killed = self._killed
except AttributeError:
# ___init__ did not run to completion (or at all).
return
cursor_id, address = self._prepare_to_die(already_killed)
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session
)
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
# These errors mean that the server has already killed the cursor so there is
# no need to send killCursors.

View File

@ -18,12 +18,12 @@
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
from typing import TYPE_CHECKING, Any, Mapping, Optional, TypedDict
from pymongo.uri_parser_shared import _parse_kms_tls_options
try:
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
# Check for pymongocrypt>=1.10.
from pymongocrypt import synchronous as _ # noqa: F401
@ -32,7 +32,7 @@ try:
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.common import validate_is_mapping
from pymongo.common import check_for_min_version, validate_is_mapping
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
@ -40,6 +40,18 @@ if TYPE_CHECKING:
from pymongo.typings import _AgnosticMongoClient
def check_min_pymongocrypt() -> None:
"""Raise an appropriate error if the min pymongocrypt is not installed."""
pymongocrypt_version, required_version, is_valid = check_for_min_version("pymongocrypt")
if not is_valid:
raise ConfigurationError(
f"client side encryption requires pymongocrypt>={required_version}, "
f"found version {pymongocrypt_version}. "
"Install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
@ -215,6 +227,7 @@ class AutoEncryptionOpts:
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
check_min_pymongocrypt()
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
@ -295,3 +308,85 @@ class RangeOpts:
if v is not None:
doc[k] = v
return doc
class TextOpts:
"""**BETA** Options to configure encrypted queries using the text algorithm.
TextOpts is currently unstable API and subject to backwards breaking changes."""
def __init__(
self,
substring: Optional[SubstringOpts] = None,
prefix: Optional[PrefixOpts] = None,
suffix: Optional[SuffixOpts] = None,
case_sensitive: Optional[bool] = None,
diacritic_sensitive: Optional[bool] = None,
) -> None:
"""Options to configure encrypted queries using the text algorithm.
:param substring: Further options to support substring queries.
:param prefix: Further options to support prefix queries.
:param suffix: Further options to support suffix queries.
:param case_sensitive: Whether text indexes for this field are case sensitive.
:param diacritic_sensitive: Whether text indexes for this field are diacritic sensitive.
.. versionadded:: 4.15
"""
self.substring = substring
self.prefix = prefix
self.suffix = suffix
self.case_sensitive = case_sensitive
self.diacritic_sensitive = diacritic_sensitive
@property
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("substring", self.substring),
("prefix", self.prefix),
("suffix", self.suffix),
("caseSensitive", self.case_sensitive),
("diacriticSensitive", self.diacritic_sensitive),
]:
if v is not None:
doc[k] = v
return doc
class SubstringOpts(TypedDict):
"""**BETA** Options for substring text queries.
SubstringOpts is currently unstable API and subject to backwards breaking changes.
"""
# strMaxLength is the maximum allowed length to insert. Inserting longer strings will error.
strMaxLength: int
# strMinQueryLength is the minimum allowed query length. Querying with a shorter string will error.
strMinQueryLength: int
# strMaxQueryLength is the maximum allowed query length. Querying with a longer string will error.
strMaxQueryLength: int
class PrefixOpts(TypedDict):
"""**BETA** Options for prefix text queries.
PrefixOpts is currently unstable API and subject to backwards breaking changes.
"""
# strMinQueryLength is the minimum allowed query length. Querying with a shorter string will error.
strMinQueryLength: int
# strMaxQueryLength is the maximum allowed query length. Querying with a longer string will error.
strMaxQueryLength: int
class SuffixOpts(TypedDict):
"""**BETA** Options for suffix text queries.
SuffixOpts is currently unstable API and subject to backwards breaking changes.
"""
# strMinQueryLength is the minimum allowed query length. Querying with a shorter string will error.
strMinQueryLength: int
# strMaxQueryLength is the maximum allowed query length. Querying with a longer string will error.
strMaxQueryLength: int

View File

@ -1298,8 +1298,6 @@ def _batched_write_command_impl(
# Start of payload
buf.seek(-1, 2)
# Work around some Jython weirdness.
buf.truncate()
try:
buf.write(_OP_MAP[operation])
except KeyError:
@ -1352,7 +1350,9 @@ class _OpReply:
UNPACK_FROM = struct.Struct("<iqii").unpack_from
OP_CODE = 1
def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
def __init__(
self, flags: int, cursor_id: int, number_returned: int, documents: bytes | memoryview
):
self.flags = flags
self.cursor_id = Int64(cursor_id)
self.number_returned = number_returned
@ -1360,7 +1360,7 @@ class _OpReply:
def raw_response(
self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
) -> list[bytes]:
) -> list[bytes | memoryview]:
"""Check the response header from the database, without decoding BSON.
Check the response for errors and unpack.
@ -1448,7 +1448,7 @@ class _OpReply:
return False
@classmethod
def unpack(cls, msg: bytes) -> _OpReply:
def unpack(cls, msg: bytes | memoryview) -> _OpReply:
"""Construct an _OpReply from raw bytes."""
# PYTHON-945: ignore starting_from field.
flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)
@ -1470,7 +1470,7 @@ class _OpMsg:
MORE_TO_COME = 1 << 1
EXHAUST_ALLOWED = 1 << 16 # Only present on requests.
def __init__(self, flags: int, payload_document: bytes):
def __init__(self, flags: int, payload_document: bytes | memoryview):
self.flags = flags
self.payload_document = payload_document
@ -1512,7 +1512,7 @@ class _OpMsg:
"""Unpack a command response."""
return self.unpack_response(codec_options=codec_options)[0]
def raw_command_response(self) -> bytes:
def raw_command_response(self) -> bytes | memoryview:
"""Return the bytes of the command response."""
return self.payload_document
@ -1522,7 +1522,7 @@ class _OpMsg:
return bool(self.flags & self.MORE_TO_COME)
@classmethod
def unpack(cls, msg: bytes) -> _OpMsg:
def unpack(cls, msg: bytes | memoryview) -> _OpMsg:
"""Construct an _OpMsg from raw bytes."""
flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
if flags != 0:
@ -1541,7 +1541,7 @@ class _OpMsg:
return cls(flags, payload_document)
_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
_UNPACK_REPLY: dict[int, Callable[[bytes | memoryview], Union[_OpReply, _OpMsg]]] = {
_OpReply.OP_CODE: _OpReply.unpack,
_OpMsg.OP_CODE: _OpMsg.unpack,
}

Some files were not shown because too many files have changed in this diff Show More