From 89f4e5c786d4439c92914614474306910ccc8142 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 20 Nov 2024 09:21:30 -0600 Subject: [PATCH 1/9] PYTHON-3730 Ensure C extensions when running the test suite (#2013) --- .evergreen/run-tests.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 36fa76e31..1e03e2714 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -47,6 +47,11 @@ else echo "Not sourcing secrets" fi +# Ensure C extensions have compiled. +if [ -z "${NO_EXT:-}" ]; then + python tools/fail_if_no_c.py +fi + if [ "$AUTH" != "noauth" ]; then if [ ! -z "$TEST_DATA_LAKE" ]; then export DB_USER="mhuser" From 906d021bb1a8b4c381ce79f943c5c57b4314ecb7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 20 Nov 2024 11:56:10 -0600 Subject: [PATCH 2/9] PYTHON-4447 Test OIDC on Server Latest (#2014) --- .evergreen/config.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index fc1713a88..1e4996c28 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -922,9 +922,6 @@ task_groups: params: binary: bash include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] - env: - # PYTHON-4447 - MONGODB_VERSION: "8.0" args: - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh teardown_task: From 1c7a7fe9ec3119228bc7bf98f1f9de199a4f8f2c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 20 Nov 2024 14:47:28 -0500 Subject: [PATCH 3/9] PYTHON-4721 - Create individualized scripts for all shell.exec commands (#1997) Co-authored-by: Jib --- .evergreen/config.yml | 706 ++++++++---------- .evergreen/hatch.sh | 2 +- .evergreen/run-tests.sh | 8 +- .evergreen/scripts/archive-mongodb-logs.sh | 8 + .../scripts/bootstrap-mongo-orchestration.sh | 46 ++ .evergreen/scripts/check-import-time.sh | 7 + .evergreen/scripts/cleanup.sh | 7 + .evergreen/scripts/configure-env.sh | 19 +- .../scripts/download-and-merge-coverage.sh | 4 + .evergreen/scripts/fix-absolute-paths.sh | 8 + .evergreen/scripts/init-test-results.sh | 5 + .evergreen/scripts/install-dependencies.sh | 6 + .evergreen/scripts/make-files-executable.sh | 8 + .evergreen/scripts/prepare-resources.sh | 12 + .evergreen/scripts/run-atlas-tests.sh | 7 + .evergreen/scripts/run-aws-ecs-auth-test.sh | 15 + .evergreen/scripts/run-doctests.sh | 4 + .../scripts/run-enterprise-auth-tests.sh | 6 + .evergreen/scripts/run-gcpkms-fail-test.sh | 7 + .evergreen/scripts/run-getdata.sh | 22 + .evergreen/scripts/run-load-balancer.sh | 3 + .evergreen/scripts/run-mockupdb-tests.sh | 5 + .../{ => scripts}/run-mod-wsgi-tests.sh | 2 +- .../{ => scripts}/run-mongodb-aws-test.sh | 10 +- .evergreen/scripts/run-ocsp-test.sh | 8 + .evergreen/scripts/run-perf-tests.sh | 4 + .evergreen/scripts/run-tests.sh | 55 ++ .evergreen/scripts/run-with-env.sh | 21 + .evergreen/scripts/setup-encryption.sh | 5 + .evergreen/scripts/setup-tests.sh | 27 + .evergreen/scripts/stop-load-balancer.sh | 5 + .evergreen/scripts/teardown-aws.sh | 7 + .evergreen/scripts/teardown-docker.sh | 7 + .evergreen/scripts/upload-coverage-report.sh | 3 + .evergreen/scripts/windows-fix.sh | 11 + .evergreen/setup-encryption.sh | 7 +- .evergreen/utils.sh | 4 +- test/asynchronous/test_client_context.py | 8 +- test/mod_wsgi_test/README.rst | 2 +- test/test_client_context.py | 8 +- 40 files changed, 681 insertions(+), 428 deletions(-) create mode 100644 .evergreen/scripts/archive-mongodb-logs.sh create mode 100644 .evergreen/scripts/bootstrap-mongo-orchestration.sh create mode 100644 .evergreen/scripts/check-import-time.sh create mode 100644 .evergreen/scripts/cleanup.sh create mode 100644 .evergreen/scripts/download-and-merge-coverage.sh create mode 100644 .evergreen/scripts/fix-absolute-paths.sh create mode 100644 .evergreen/scripts/init-test-results.sh create mode 100644 .evergreen/scripts/install-dependencies.sh create mode 100644 .evergreen/scripts/make-files-executable.sh create mode 100644 .evergreen/scripts/prepare-resources.sh create mode 100644 .evergreen/scripts/run-atlas-tests.sh create mode 100644 .evergreen/scripts/run-aws-ecs-auth-test.sh create mode 100644 .evergreen/scripts/run-doctests.sh create mode 100644 .evergreen/scripts/run-enterprise-auth-tests.sh create mode 100644 .evergreen/scripts/run-gcpkms-fail-test.sh create mode 100644 .evergreen/scripts/run-getdata.sh create mode 100644 .evergreen/scripts/run-load-balancer.sh create mode 100644 .evergreen/scripts/run-mockupdb-tests.sh rename .evergreen/{ => scripts}/run-mod-wsgi-tests.sh (97%) rename .evergreen/{ => scripts}/run-mongodb-aws-test.sh (67%) create mode 100644 .evergreen/scripts/run-ocsp-test.sh create mode 100644 .evergreen/scripts/run-perf-tests.sh create mode 100644 .evergreen/scripts/run-tests.sh create mode 100644 .evergreen/scripts/run-with-env.sh create mode 100644 .evergreen/scripts/setup-encryption.sh create mode 100644 .evergreen/scripts/setup-tests.sh create mode 100644 .evergreen/scripts/stop-load-balancer.sh create mode 100644 .evergreen/scripts/teardown-aws.sh create mode 100644 .evergreen/scripts/teardown-docker.sh create mode 100644 .evergreen/scripts/upload-coverage-report.sh create mode 100644 .evergreen/scripts/windows-fix.sh diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 1e4996c28..59b8a543f 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -20,10 +20,9 @@ exec_timeout_secs: 3600 # 60 minutes is the longest we'll ever run (primarily # What to do when evergreen hits the timeout (`post:` tasks are run automatically) timeout: - - command: shell.exec + - command: subprocess.exec params: - script: | - ls -la + binary: ls -la include: - filename: .evergreen/generated_configs/tasks.yml @@ -41,7 +40,7 @@ functions: # Make an evergreen expansion file with dynamic values - command: subprocess.exec params: - include_expansions_in_env: ["is_patch", "project", "version_id"] + include_expansions_in_env: ["is_patch", "project", "version_id", "AUTH", "SSL", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "SETDEFAULTENCODING", "test_loadbalancer", "test_serverless", "SKIP_CSOT_TESTS", "MONGODB_STARTED", "DISABLE_TEST_COMMANDS", "GREEN_FRAMEWORK", "NO_EXT", "COVERAGE", "COMPRESSORS", "TEST_SUITES", "MONGODB_API_VERSION", "SKIP_HATCH", "skip_crypt_shared", "VERSION", "TOPOLOGY", "STORAGE_ENGINE", "ORCHESTRATION_FILE", "REQUIRE_API_VERSION", "LOAD_BALANCER", "skip_web_identity_auth_test", "skip_ECS_auth_test"] binary: bash working_dir: "src" args: @@ -52,19 +51,11 @@ functions: file: src/expansion.yml "prepare resources": - - command: shell.exec + - command: subprocess.exec params: - script: | - . src/.evergreen/scripts/env.sh - set -o xtrace - rm -rf $DRIVERS_TOOLS - if [ "$PROJECT" = "drivers-tools" ]; then - # If this was a patch build, doing a fresh clone would not actually test the patch - cp -R ${PROJECT_DIRECTORY}/ ${DRIVERS_TOOLS} - else - git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git ${DRIVERS_TOOLS} - fi - echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config + binary: bash + args: + - src/.evergreen/scripts/prepare-resources.sh "upload coverage" : - command: ec2.assume_role @@ -88,14 +79,17 @@ functions: - command: ec2.assume_role params: role_arn: ${assume_role_arn} - - command: shell.exec + - command: subprocess.exec params: silent: true + binary: bash working_dir: "src" include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] - script: | - # Download all the task coverage files. - aws s3 cp --recursive s3://${bucket_name}/coverage/${revision}/${version_id}/coverage/ coverage/ + args: + - .evergreen/scripts/download-and-merge-coverage.sh + - ${bucket_name} + - ${revision} + - ${version_id} - command: subprocess.exec params: working_dir: "src" @@ -103,13 +97,17 @@ functions: args: - .evergreen/combine-coverage.sh # Upload the resulting html coverage report. - - command: shell.exec + - command: subprocess.exec params: silent: true + binary: bash working_dir: "src" include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] - script: | - aws s3 cp htmlcov/ s3://${bucket_name}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1 + args: + - .evergreen/scripts/upload-coverage-report.sh + - ${bucket_name} + - ${revision} + - ${version_id} # Attach the index.html with s3.put so it shows up in the Evergreen UI. - command: s3.put params: @@ -128,15 +126,6 @@ functions: - command: ec2.assume_role params: role_arn: ${assume_role_arn} - - command: shell.exec - params: - script: | - . src/.evergreen/scripts/env.sh - set -o xtrace - mkdir out_dir - find $MONGO_ORCHESTRATION_HOME -name \*.log -exec sh -c 'x="{}"; mv $x $PWD/out_dir/$(basename $(dirname $x))_$(basename $x)' \; - tar zcvf mongodb-logs.tar.gz -C out_dir/ . - rm -rf out_dir - command: archive.targz_pack params: target: "mongo-coredumps.tgz" @@ -161,23 +150,12 @@ functions: aws_key: ${AWS_ACCESS_KEY_ID} aws_secret: ${AWS_SECRET_ACCESS_KEY} aws_session_token: ${AWS_SESSION_TOKEN} - local_file: mongodb-logs.tar.gz - remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz + local_file: ${DRIVERS_TOOLS}/.evergreen/test_logs.tar.gz + remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-drivers-tools-logs.tar.gz bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/x-gzip} - display_name: "mongodb-logs.tar.gz" - - command: s3.put - params: - aws_key: ${AWS_ACCESS_KEY_ID} - aws_secret: ${AWS_SECRET_ACCESS_KEY} - aws_session_token: ${AWS_SESSION_TOKEN} - local_file: drivers-tools/.evergreen/orchestration/server.log - remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log - bucket: ${bucket_name} - permissions: public-read - content_type: ${content_type|text/plain} - display_name: "orchestration.log" + display_name: "drivers-tools-logs.tar.gz" "upload working dir": - command: ec2.assume_role @@ -230,54 +208,13 @@ functions: file: "src/xunit-results/TEST-*.xml" "bootstrap mongo-orchestration": - - command: shell.exec + - command: subprocess.exec params: - script: | - . src/.evergreen/scripts/env.sh - set -o xtrace - - # Enable core dumps if enabled on the machine - # Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml - if [ -f /proc/self/coredump_filter ]; then - # Set the shell process (and its children processes) to dump ELF headers (bit 4), - # anonymous shared mappings (bit 1), and anonymous private mappings (bit 0). - echo 0x13 > /proc/self/coredump_filter - - if [ -f /sbin/sysctl ]; then - # Check that the core pattern is set explicitly on our distro image instead - # of being the OS's default value. This ensures that coredump names are consistent - # across distros and can be picked up by Evergreen. - core_pattern=$(/sbin/sysctl -n "kernel.core_pattern") - if [ "$core_pattern" = "dump_%e.%p.core" ]; then - echo "Enabling coredumps" - ulimit -c unlimited - fi - fi - fi - - if [ $(uname -s) = "Darwin" ]; then - core_pattern_mac=$(/usr/sbin/sysctl -n "kern.corefile") - if [ "$core_pattern_mac" = "dump_%N.%P.core" ]; then - echo "Enabling coredumps" - ulimit -c unlimited - fi - fi - - if [ -n "${skip_crypt_shared}" ]; then - export SKIP_CRYPT_SHARED=1 - fi - - MONGODB_VERSION=${VERSION} \ - TOPOLOGY=${TOPOLOGY} \ - AUTH=${AUTH} \ - SSL=${SSL} \ - STORAGE_ENGINE=${STORAGE_ENGINE} \ - DISABLE_TEST_COMMANDS=${DISABLE_TEST_COMMANDS} \ - ORCHESTRATION_FILE=${ORCHESTRATION_FILE} \ - REQUIRE_API_VERSION=${REQUIRE_API_VERSION} \ - LOAD_BALANCER=${LOAD_BALANCER} \ - bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh - # run-orchestration generates expansion file with the MONGODB_URI for the cluster + binary: bash + include_expansions_in_env: ["VERSION", "TOPOLOGY", "AUTH", "SSL", "ORCHESTRATION_FILE", "LOAD_BALANCER"] + args: + - src/.evergreen/scripts/run-with-env.sh + - src/.evergreen/scripts/bootstrap-mongo-orchestration.sh - command: expansions.update params: file: mo-expansion.yml @@ -288,167 +225,107 @@ functions: value: "1" "bootstrap data lake": - - command: shell.exec + - command: subprocess.exec type: setup params: - script: | - . src/.evergreen/scripts/env.sh - bash ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/pull-mongohouse-image.sh - - command: shell.exec + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/pull-mongohouse-image.sh + - command: subprocess.exec type: setup params: - script: | - . src/.evergreen/scripts/env.sh - bash ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/run-mongohouse-image.sh - sleep 1 - docker ps + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/run-mongohouse-image.sh "stop mongo-orchestration": - - command: shell.exec + - command: subprocess.exec params: - script: | - . src/.evergreen/scripts/env.sh - set -o xtrace - bash ${DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh "run mod_wsgi tests": - - command: shell.exec + - command: subprocess.exec type: test params: + include_expansions_in_env: [MOD_WSGI_VERSION, MOD_WSGI_EMBEDDED, "PYTHON_BINARY"] working_dir: "src" - script: | - . .evergreen/scripts/env.sh - set -o xtrace - PYTHON_BINARY=${PYTHON_BINARY} MOD_WSGI_VERSION=${MOD_WSGI_VERSION} \ - MOD_WSGI_EMBEDDED=${MOD_WSGI_EMBEDDED} PROJECT_DIRECTORY=${PROJECT_DIRECTORY} \ - bash ${PROJECT_DIRECTORY}/.evergreen/run-mod-wsgi-tests.sh + binary: bash + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mod-wsgi-tests.sh "run mockupdb tests": - - command: shell.exec + - command: subprocess.exec type: test params: + include_expansions_in_env: ["PYTHON_BINARY"] working_dir: "src" - script: | - . .evergreen/scripts/env.sh - set -o xtrace - export PYTHON_BINARY=${PYTHON_BINARY} - bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh test:test-mockupdb + binary: bash + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mockupdb-tests.sh "run doctests": - - command: shell.exec + - command: subprocess.exec type: test params: + include_expansions_in_env: [ "PYTHON_BINARY" ] working_dir: "src" - script: | - . .evergreen/scripts/env.sh - set -o xtrace - PYTHON_BINARY=${PYTHON_BINARY} bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh doctest:test + binary: bash + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-doctests.sh "run tests": - - command: shell.exec + - command: subprocess.exec params: + include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] + binary: bash working_dir: "src" - shell: bash - background: true - include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] - script: | - . .evergreen/scripts/env.sh - if [ -n "${test_encryption}" ]; then - ./.evergreen/hatch.sh encryption:setup - fi - - command: shell.exec + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec + params: + working_dir: "src" + binary: bash + background: true + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/setup-encryption.sh + - command: subprocess.exec type: test params: working_dir: "src" - shell: bash - script: | - # Disable xtrace - set +x - . .evergreen/scripts/env.sh - if [ -n "${MONGODB_STARTED}" ]; then - export PYMONGO_MUST_CONNECT=true - fi - if [ -n "${DISABLE_TEST_COMMANDS}" ]; then - export PYMONGO_DISABLE_TEST_COMMANDS=1 - fi - if [ -n "${test_encryption}" ]; then - # Disable xtrace (just in case it was accidentally set). - set +x - bash ${DRIVERS_TOOLS}/.evergreen/csfle/await-servers.sh - export TEST_ENCRYPTION=1 - if [ -n "${test_encryption_pyopenssl}" ]; then - export TEST_ENCRYPTION_PYOPENSSL=1 - fi - fi - if [ -n "${test_crypt_shared}" ]; then - export TEST_CRYPT_SHARED=1 - export CRYPT_SHARED_LIB_PATH=${CRYPT_SHARED_LIB_PATH} - fi - if [ -n "${test_pyopenssl}" ]; then - export TEST_PYOPENSSL=1 - fi - if [ -n "${SETDEFAULTENCODING}" ]; then - export SETDEFAULTENCODING="${SETDEFAULTENCODING}" - fi - if [ -n "${test_loadbalancer}" ]; then - export TEST_LOADBALANCER=1 - export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}" - export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}" - fi - if [ -n "${test_serverless}" ]; then - export TEST_SERVERLESS=1 - fi - if [ -n "${TEST_INDEX_MANAGEMENT}" ]; then - export TEST_INDEX_MANAGEMENT=1 - fi - if [ -n "${SKIP_CSOT_TESTS}" ]; then - export SKIP_CSOT_TESTS=1 - fi - - GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \ - PYTHON_BINARY=${PYTHON_BINARY} \ - NO_EXT=${NO_EXT} \ - COVERAGE=${COVERAGE} \ - COMPRESSORS=${COMPRESSORS} \ - AUTH=${AUTH} \ - SSL=${SSL} \ - TEST_DATA_LAKE=${TEST_DATA_LAKE} \ - TEST_SUITES=${TEST_SUITES} \ - MONGODB_API_VERSION=${MONGODB_API_VERSION} \ - SKIP_HATCH=${SKIP_HATCH} \ - bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh test:test-eg + binary: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "PYTHON_BINARY", "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "SINGLE_MONGOS_LB_URI", "MULTI_MONGOS_LB_URI", "TEST_SUITES"] + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-tests.sh "run enterprise auth tests": - - command: shell.exec + - command: subprocess.exec type: test params: + binary: bash working_dir: "src" - include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] - script: | - # Disable xtrace for security reasons (just in case it was accidentally set). - set +x - bash ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup_secrets.sh drivers/enterprise_auth - PROJECT_DIRECTORY="${PROJECT_DIRECTORY}" \ - PYTHON_BINARY="${PYTHON_BINARY}" \ - TEST_ENTERPRISE_AUTH=1 \ - AUTH=auth \ - bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh test:test-eg + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "PYTHON_BINARY"] + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-enterprise-auth-tests.sh "run atlas tests": - - command: shell.exec + - command: subprocess.exec type: test params: - include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + binary: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "PYTHON_BINARY"] working_dir: "src" - script: | - # Disable xtrace for security reasons (just in case it was accidentally set). - set +x - set -o errexit - bash ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup_secrets.sh drivers/atlas_connect - PROJECT_DIRECTORY="${PROJECT_DIRECTORY}" \ - PYTHON_BINARY="${PYTHON_BINARY}" \ - TEST_ATLAS=1 \ - bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh test:test-eg + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-atlas-tests.sh "get aws auth secrets": - command: subprocess.exec @@ -460,57 +337,140 @@ functions: - ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup-secrets.sh "run aws auth test with regular aws credentials": - - command: shell.exec + - command: subprocess.exec + params: + include_expansions_in_env: ["TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec type: test params: - shell: "bash" + include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"] + binary: bash working_dir: "src" - script: | - . .evergreen/scripts/env.sh - .evergreen/run-mongodb-aws-test.sh regular + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - regular "run aws auth test with assume role credentials": - - command: shell.exec + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec type: test params: - shell: "bash" + include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"] + binary: bash working_dir: "src" - script: | - . .evergreen/scripts/env.sh - .evergreen/run-mongodb-aws-test.sh assume-role + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - assume-role "run aws auth test with aws EC2 credentials": - - command: shell.exec + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec type: test params: + include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"] + binary: bash working_dir: "src" - shell: "bash" - script: | - if [ "${skip_EC2_auth_test}" = "true" ]; then - echo "This platform does not support the EC2 auth test, skipping..." - exit 0 - fi - . .evergreen/scripts/env.sh - .evergreen/run-mongodb-aws-test.sh ec2 + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - ec2 "run aws auth test with aws web identity credentials": - - command: shell.exec + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - # Test with and without AWS_ROLE_SESSION_NAME set. + - command: subprocess.exec type: test params: + include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"] + binary: bash working_dir: "src" - shell: "bash" - script: | - if [ "${skip_EC2_auth_test}" = "true" ]; then - echo "This platform does not support the web identity auth test, skipping..." - exit 0 - fi - . .evergreen/scripts/env.sh - # Test with and without AWS_ROLE_SESSION_NAME set. - .evergreen/run-mongodb-aws-test.sh web-identity - AWS_ROLE_SESSION_NAME="test" \ - .evergreen/run-mongodb-aws-test.sh web-identity + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - web-identity + - command: subprocess.exec + type: test + params: + include_expansions_in_env: [ "DRIVERS_TOOLS", "skip_EC2_auth_test" ] + binary: bash + working_dir: "src" + env: + AWS_ROLE_SESSION_NAME: test + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - web-identity + + "run aws auth test with aws credentials as environment variables": + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec + type: test + params: + include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - env-creds + + "run aws auth test with aws credentials and session token as environment variables": + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec + type: test + params: + include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-mongodb-aws-test.sh + - session-creds "run oidc auth test with test credentials": + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh - command: subprocess.exec type: test params: @@ -532,110 +492,69 @@ functions: args: - ${PROJECT_DIRECTORY}/.evergreen/run-mongodb-oidc-remote-test.sh - "run aws auth test with aws credentials as environment variables": - - command: shell.exec - type: test - params: - working_dir: "src" - shell: bash - script: | - . .evergreen/scripts/env.sh - .evergreen/run-mongodb-aws-test.sh env-creds - - "run aws auth test with aws credentials and session token as environment variables": - - command: shell.exec - type: test - params: - working_dir: "src" - shell: bash - script: | - . .evergreen/scripts/env.sh - .evergreen/run-mongodb-aws-test.sh session-creds - "run aws ECS auth test": - - command: shell.exec + - command: subprocess.exec type: test params: - shell: "bash" + binary: bash working_dir: "src" - script: | - if [ "${skip_ECS_auth_test}" = "true" ]; then - echo "This platform does not support the ECS auth test, skipping..." - exit 0 - fi - . .evergreen/scripts/env.sh - set -ex - cd ${DRIVERS_TOOLS}/.evergreen/auth_aws - . ./activate-authawsvenv.sh - . aws_setup.sh ecs - export MONGODB_BINARIES="$MONGODB_BINARIES"; - export PROJECT_DIRECTORY="${PROJECT_DIRECTORY}"; - python aws_tester.py ecs - cd - + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-aws-ecs-auth-test.sh "cleanup": - - command: shell.exec + - command: subprocess.exec params: + binary: bash working_dir: "src" - script: | - . .evergreen/scripts/env.sh - if [ -f $DRIVERS_TOOLS/.evergreen/csfle/secrets-export.sh ]; then - . .evergreen/hatch.sh encryption:teardown - fi - rm -rf ${DRIVERS_TOOLS} || true - rm -f ./secrets-export.sh || true + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/cleanup.sh + + "teardown": + - command: subprocess.exec + params: + binary: bash + working_dir: "src" + args: + - ${DRIVERS_TOOLS}/.evergreen/teardown.sh "fix absolute paths": - - command: shell.exec + - command: subprocess.exec params: - script: | - set +x - . src/.evergreen/scripts/env.sh - for filename in $(find ${DRIVERS_TOOLS} -name \*.json); do - perl -p -i -e "s|ABSOLUTE_PATH_REPLACEMENT_TOKEN|${DRIVERS_TOOLS}|g" $filename - done + binary: bash + args: + - src/.evergreen/scripts/fix-absolute-paths.sh "windows fix": - - command: shell.exec + - command: subprocess.exec params: - script: | - set +x - . src/.evergreen/scripts/env.sh - for i in $(find ${DRIVERS_TOOLS}/.evergreen ${PROJECT_DIRECTORY}/.evergreen -name \*.sh); do - cat $i | tr -d '\r' > $i.new - mv $i.new $i - done - # Copy client certificate because symlinks do not work on Windows. - cp ${DRIVERS_TOOLS}/.evergreen/x509gen/client.pem $MONGO_ORCHESTRATION_HOME/lib/client.pem + binary: bash + args: + - src/.evergreen/scripts/windows-fix.sh "make files executable": - - command: shell.exec + - command: subprocess.exec params: - script: | - set +x - . src/.evergreen/scripts/env.sh - for i in $(find ${DRIVERS_TOOLS}/.evergreen ${PROJECT_DIRECTORY}/.evergreen -name \*.sh); do - chmod +x $i - done + binary: bash + args: + - src/.evergreen/scripts/make-files-executable.sh "init test-results": - - command: shell.exec + - command: subprocess.exec params: - script: | - set +x - . src/.evergreen/scripts/env.sh - echo '{"results": [{ "status": "FAIL", "test_file": "Build", "log_raw": "No test-results.json found was created" } ]}' > ${PROJECT_DIRECTORY}/test-results.json + binary: bash + args: + - src/.evergreen/scripts/init-test-results.sh "install dependencies": - - command: shell.exec + - command: subprocess.exec params: + binary: bash working_dir: "src" - script: | - . .evergreen/scripts/env.sh - set -o xtrace - file="${PROJECT_DIRECTORY}/.evergreen/install-dependencies.sh" - # Don't use ${file} syntax here because evergreen treats it as an empty expansion. - [ -f "$file" ] && bash $file || echo "$file not available, skipping" + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/install-dependencies.sh "assume ec2 role": - command: ec2.assume_role @@ -657,18 +576,22 @@ functions: file: atlas-expansion.yml "run-ocsp-test": - - command: shell.exec + - command: subprocess.exec + params: + include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ] + binary: bash + working_dir: "src" + args: + - .evergreen/scripts/setup-tests.sh + - command: subprocess.exec type: test params: + include_expansions_in_env: ["OCSP_ALGORITHM", "OCSP_TLS_SHOULD_SUCCEED", "PYTHON_BINARY"] + binary: bash working_dir: "src" - script: | - . .evergreen/scripts/env.sh - TEST_OCSP=1 \ - PYTHON_BINARY=${PYTHON_BINARY} \ - CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem" \ - OCSP_TLS_SHOULD_SUCCEED="${OCSP_TLS_SHOULD_SUCCEED}" \ - bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh test:test-eg - bash ${DRIVERS_TOOLS}/.evergreen/ocsp/teardown.sh + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-ocsp-test.sh "run-ocsp-server": - command: subprocess.exec @@ -680,42 +603,38 @@ functions: - ${DRIVERS_TOOLS}/.evergreen/ocsp/setup.sh "run load-balancer": - - command: shell.exec + - command: subprocess.exec params: - script: | - DRIVERS_TOOLS=${DRIVERS_TOOLS} MONGODB_URI=${MONGODB_URI} bash ${DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh start + binary: bash + include_expansions_in_env: ["MONGODB_URI"] + args: + - src/.evergreen/scripts/run-with-env.sh + - src/.evergreen/scripts/run-load-balancer.sh - command: expansions.update params: file: lb-expansion.yml "stop load-balancer": - - command: shell.exec + - command: subprocess.exec params: - script: | - cd ${DRIVERS_TOOLS}/.evergreen - DRIVERS_TOOLS=${DRIVERS_TOOLS} bash ${DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh stop + binary: bash + args: + - src/.evergreen/scripts/stop-load-balancer.sh "teardown_docker": - - command: shell.exec + - command: subprocess.exec params: - script: | - # Remove all Docker images - DOCKER=$(command -v docker) || true - if [ -n "$DOCKER" ]; then - docker rmi -f $(docker images -a -q) &> /dev/null || true - fi + binary: bash + args: + - src/.evergreen/scripts/teardown-docker.sh "teardown_aws": - - command: shell.exec + - command: subprocess.exec params: - shell: "bash" - script: | - . src/.evergreen/scripts/env.sh - cd "${DRIVERS_TOOLS}/.evergreen/auth_aws" - if [ -f "./aws_e2e_setup.json" ]; then - . ./activate-authawsvenv.sh - python ./lib/aws_assign_instance_profile.py - fi + binary: bash + args: + - src/.evergreen/scripts/run-with-env.sh + - src/.evergreen/scripts/teardown-aws.sh "teardown atlas": - command: subprocess.exec @@ -725,13 +644,14 @@ functions: - ${DRIVERS_TOOLS}/.evergreen/atlas/teardown-atlas-cluster.sh "run perf tests": - - command: shell.exec + - command: subprocess.exec type: test params: working_dir: "src" - script: | - . .evergreen/scripts/env.sh - PROJECT_DIRECTORY=${PROJECT_DIRECTORY} bash ${PROJECT_DIRECTORY}/.evergreen/run-perf-tests.sh + binary: bash + args: + - .evergreen/scripts/run-with-env.sh + - .evergreen/scripts/run-perf-tests.sh "attach benchmark test results": - command: attach.results @@ -756,6 +676,7 @@ pre: post: # Disabled, causing timeouts # - func: "upload working dir" + - func: "teardown" - func: "upload coverage" - func: "upload mo artifacts" - func: "upload test results" @@ -798,13 +719,13 @@ task_groups: - func: make files executable - command: subprocess.exec params: - binary: "bash" + binary: bash args: - ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/create-and-setup-instance.sh teardown_task: - command: subprocess.exec params: - binary: "bash" + binary: bash args: - ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/delete-instance.sh - func: "upload test results" @@ -966,31 +887,12 @@ tasks: # Throw it here, and execute this task on all buildvariants - name: getdata commands: - - command: shell.exec + - command: subprocess.exec + binary: bash type: test params: - script: | - set -o xtrace - . ${DRIVERS_TOOLS}/.evergreen/download-mongodb.sh || true - get_distro || true - echo $DISTRO - echo $MARCH - echo $OS - uname -a || true - ls /etc/*release* || true - cc --version || true - gcc --version || true - clang --version || true - gcov --version || true - lcov --version || true - llvm-cov --version || true - echo $PATH - ls -la /usr/local/Cellar/llvm/*/bin/ || true - ls -la /usr/local/Cellar/ || true - scan-build --version || true - genhtml --version || true - valgrind --version || true - + args: + - src/.evergreen/scripts/run-getdata.sh # Standard test tasks {{{ - name: "mockupdb" @@ -1647,7 +1549,7 @@ tasks: type: setup params: working_dir: "src" - binary: "bash" + binary: bash include_expansions_in_env: ["DRIVERS_TOOLS"] args: - .evergreen/run-gcpkms-test.sh @@ -1660,17 +1562,14 @@ tasks: vars: VERSION: "latest" TOPOLOGY: "server" - - command: shell.exec + - command: subprocess.exec type: test params: + include_expansions_in_env: ["PYTHON_BINARY"] working_dir: "src" - shell: "bash" - script: | - . .evergreen/scripts/env.sh - export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 - export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz - SKIP_SERVERS=1 bash ./.evergreen/setup-encryption.sh - SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/hatch.sh test:test-eg + binary: "bash" + args: + - .evergreen/scripts/run-gcpkms-fail-test.sh - name: testazurekms-task commands: @@ -1736,18 +1635,15 @@ tasks: - name: "check-import-time" tags: ["pr"] commands: - - command: shell.exec + - command: subprocess.exec type: test params: - shell: "bash" + binary: bash working_dir: src - script: | - . .evergreen/scripts/env.sh - set -x - export BASE_SHA=${revision} - export HEAD_SHA=${github_commit} - bash .evergreen/run-import-time-test.sh - + args: + - .evergreen/scripts/check-import-time.sh + - ${revision} + - ${github_commit} - name: "backport-pr" allowed_requesters: ["commit"] commands: diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh index 45d5113cd..98cd9ed73 100644 --- a/.evergreen/hatch.sh +++ b/.evergreen/hatch.sh @@ -29,7 +29,7 @@ else # Set up virtualenv before installing hatch # Ensure hatch does not write to user or global locations. touch hatch_config.toml HATCH_CONFIG=$(pwd)/hatch_config.toml - if [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin + if [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin HATCH_CONFIG=$(cygpath -m "$HATCH_CONFIG") fi export HATCH_CONFIG diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 1e03e2714..9716c1fc7 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -53,18 +53,18 @@ if [ -z "${NO_EXT:-}" ]; then fi if [ "$AUTH" != "noauth" ]; then - if [ ! -z "$TEST_DATA_LAKE" ]; then + if [ -n "$TEST_DATA_LAKE" ]; then export DB_USER="mhuser" export DB_PASSWORD="pencil" - elif [ ! -z "$TEST_SERVERLESS" ]; then - source ${DRIVERS_TOOLS}/.evergreen/serverless/secrets-export.sh + elif [ -n "$TEST_SERVERLESS" ]; then + source "${DRIVERS_TOOLS}"/.evergreen/serverless/secrets-export.sh export DB_USER=$SERVERLESS_ATLAS_USER export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD export MONGODB_URI="$SERVERLESS_URI" echo "MONGODB_URI=$MONGODB_URI" export SINGLE_MONGOS_LB_URI=$MONGODB_URI export MULTI_MONGOS_LB_URI=$MONGODB_URI - elif [ ! -z "$TEST_AUTH_OIDC" ]; then + elif [ -n "$TEST_AUTH_OIDC" ]; then export DB_USER=$OIDC_ADMIN_USER export DB_PASSWORD=$OIDC_ADMIN_PWD export DB_IP="$MONGODB_URI" diff --git a/.evergreen/scripts/archive-mongodb-logs.sh b/.evergreen/scripts/archive-mongodb-logs.sh new file mode 100644 index 000000000..70a337cd1 --- /dev/null +++ b/.evergreen/scripts/archive-mongodb-logs.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -o xtrace +mkdir out_dir +# shellcheck disable=SC2156 +find "$MONGO_ORCHESTRATION_HOME" -name \*.log -exec sh -c 'x="{}"; mv $x $PWD/out_dir/$(basename $(dirname $x))_$(basename $x)' \; +tar zcvf mongodb-logs.tar.gz -C out_dir/ . +rm -rf out_dir diff --git a/.evergreen/scripts/bootstrap-mongo-orchestration.sh b/.evergreen/scripts/bootstrap-mongo-orchestration.sh new file mode 100644 index 000000000..1d2b145de --- /dev/null +++ b/.evergreen/scripts/bootstrap-mongo-orchestration.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +set -o xtrace + +# Enable core dumps if enabled on the machine +# Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml +if [ -f /proc/self/coredump_filter ]; then + # Set the shell process (and its children processes) to dump ELF headers (bit 4), + # anonymous shared mappings (bit 1), and anonymous private mappings (bit 0). + echo 0x13 >/proc/self/coredump_filter + + if [ -f /sbin/sysctl ]; then + # Check that the core pattern is set explicitly on our distro image instead + # of being the OS's default value. This ensures that coredump names are consistent + # across distros and can be picked up by Evergreen. + core_pattern=$(/sbin/sysctl -n "kernel.core_pattern") + if [ "$core_pattern" = "dump_%e.%p.core" ]; then + echo "Enabling coredumps" + ulimit -c unlimited + fi + fi +fi + +if [ "$(uname -s)" = "Darwin" ]; then + core_pattern_mac=$(/usr/sbin/sysctl -n "kern.corefile") + if [ "$core_pattern_mac" = "dump_%N.%P.core" ]; then + echo "Enabling coredumps" + ulimit -c unlimited + fi +fi + +if [ -n "${skip_crypt_shared}" ]; then + export SKIP_CRYPT_SHARED=1 +fi + +MONGODB_VERSION=${VERSION} \ + TOPOLOGY=${TOPOLOGY} \ + AUTH=${AUTH:-noauth} \ + SSL=${SSL:-nossl} \ + STORAGE_ENGINE=${STORAGE_ENGINE:-} \ + DISABLE_TEST_COMMANDS=${DISABLE_TEST_COMMANDS:-} \ + ORCHESTRATION_FILE=${ORCHESTRATION_FILE:-} \ + REQUIRE_API_VERSION=${REQUIRE_API_VERSION:-} \ + LOAD_BALANCER=${LOAD_BALANCER:-} \ + bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh +# run-orchestration generates expansion file with the MONGODB_URI for the cluster diff --git a/.evergreen/scripts/check-import-time.sh b/.evergreen/scripts/check-import-time.sh new file mode 100644 index 000000000..cdd2025d5 --- /dev/null +++ b/.evergreen/scripts/check-import-time.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +. .evergreen/scripts/env.sh +set -x +export BASE_SHA="$1" +export HEAD_SHA="$2" +bash .evergreen/run-import-time-test.sh diff --git a/.evergreen/scripts/cleanup.sh b/.evergreen/scripts/cleanup.sh new file mode 100644 index 000000000..9e583e4f1 --- /dev/null +++ b/.evergreen/scripts/cleanup.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +if [ -f "$DRIVERS_TOOLS"/.evergreen/csfle/secrets-export.sh ]; then + . .evergreen/hatch.sh encryption:teardown +fi +rm -rf "${DRIVERS_TOOLS}" || true +rm -f ./secrets-export.sh || true diff --git a/.evergreen/scripts/configure-env.sh b/.evergreen/scripts/configure-env.sh index 98d400037..3c0a0436d 100644 --- a/.evergreen/scripts/configure-env.sh +++ b/.evergreen/scripts/configure-env.sh @@ -1,4 +1,4 @@ -#!/bin/bash -ex +#!/bin/bash -eux # Get the current unique version of this checkout # shellcheck disable=SC2154 @@ -29,7 +29,7 @@ fi export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration" export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin" -cat < $SCRIPT_DIR/env.sh +cat < "$SCRIPT_DIR"/env.sh set -o errexit export PROJECT_DIRECTORY="$PROJECT_DIRECTORY" export CURRENT_VERSION="$CURRENT_VERSION" @@ -38,6 +38,21 @@ export DRIVERS_TOOLS="$DRIVERS_TOOLS" export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME" export MONGODB_BINARIES="$MONGODB_BINARIES" export PROJECT_DIRECTORY="$PROJECT_DIRECTORY" +export SETDEFAULTENCODING="${SETDEFAULTENCODING:-}" +export SKIP_CSOT_TESTS="${SKIP_CSOT_TESTS:-}" +export MONGODB_STARTED="${MONGODB_STARTED:-}" +export DISABLE_TEST_COMMANDS="${DISABLE_TEST_COMMANDS:-}" +export GREEN_FRAMEWORK="${GREEN_FRAMEWORK:-}" +export NO_EXT="${NO_EXT:-}" +export COVERAGE="${COVERAGE:-}" +export COMPRESSORS="${COMPRESSORS:-}" +export MONGODB_API_VERSION="${MONGODB_API_VERSION:-}" +export SKIP_HATCH="${SKIP_HATCH:-}" +export skip_crypt_shared="${skip_crypt_shared:-}" +export STORAGE_ENGINE="${STORAGE_ENGINE:-}" +export REQUIRE_API_VERSION="${REQUIRE_API_VERSION:-}" +export skip_web_identity_auth_test="${skip_web_identity_auth_test:-}" +export skip_ECS_auth_test="${skip_ECS_auth_test:-}" export TMPDIR="$MONGO_ORCHESTRATION_HOME/db" export PATH="$MONGODB_BINARIES:$PATH" diff --git a/.evergreen/scripts/download-and-merge-coverage.sh b/.evergreen/scripts/download-and-merge-coverage.sh new file mode 100644 index 000000000..808bb957e --- /dev/null +++ b/.evergreen/scripts/download-and-merge-coverage.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +# Download all the task coverage files. +aws s3 cp --recursive s3://"$1"/coverage/"$2"/"$3"/coverage/ coverage/ diff --git a/.evergreen/scripts/fix-absolute-paths.sh b/.evergreen/scripts/fix-absolute-paths.sh new file mode 100644 index 000000000..eb9433c67 --- /dev/null +++ b/.evergreen/scripts/fix-absolute-paths.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set +x +. src/.evergreen/scripts/env.sh +# shellcheck disable=SC2044 +for filename in $(find $DRIVERS_TOOLS -name \*.json); do + perl -p -i -e "s|ABSOLUTE_PATH_REPLACEMENT_TOKEN|$DRIVERS_TOOLS|g" $filename +done diff --git a/.evergreen/scripts/init-test-results.sh b/.evergreen/scripts/init-test-results.sh new file mode 100644 index 000000000..666ac6062 --- /dev/null +++ b/.evergreen/scripts/init-test-results.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +set +x +. src/.evergreen/scripts/env.sh +echo '{"results": [{ "status": "FAIL", "test_file": "Build", "log_raw": "No test-results.json found was created" } ]}' >$PROJECT_DIRECTORY/test-results.json diff --git a/.evergreen/scripts/install-dependencies.sh b/.evergreen/scripts/install-dependencies.sh new file mode 100644 index 000000000..ebcc8f306 --- /dev/null +++ b/.evergreen/scripts/install-dependencies.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +set -o xtrace +file="$PROJECT_DIRECTORY/.evergreen/install-dependencies.sh" +# Don't use ${file} syntax here because evergreen treats it as an empty expansion. +[ -f "$file" ] && bash "$file" || echo "$file not available, skipping" diff --git a/.evergreen/scripts/make-files-executable.sh b/.evergreen/scripts/make-files-executable.sh new file mode 100644 index 000000000..806be7c59 --- /dev/null +++ b/.evergreen/scripts/make-files-executable.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set +x +. src/.evergreen/scripts/env.sh +# shellcheck disable=SC2044 +for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do + chmod +x "$i" +done diff --git a/.evergreen/scripts/prepare-resources.sh b/.evergreen/scripts/prepare-resources.sh new file mode 100644 index 000000000..33394b55f --- /dev/null +++ b/.evergreen/scripts/prepare-resources.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +. src/.evergreen/scripts/env.sh +set -o xtrace +rm -rf $DRIVERS_TOOLS +if [ "$PROJECT" = "drivers-tools" ]; then + # If this was a patch build, doing a fresh clone would not actually test the patch + cp -R $PROJECT_DIRECTORY/ $DRIVERS_TOOLS +else + git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git $DRIVERS_TOOLS +fi +echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" >$MONGO_ORCHESTRATION_HOME/orchestration.config diff --git a/.evergreen/scripts/run-atlas-tests.sh b/.evergreen/scripts/run-atlas-tests.sh new file mode 100644 index 000000000..98a19f047 --- /dev/null +++ b/.evergreen/scripts/run-atlas-tests.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +# Disable xtrace for security reasons (just in case it was accidentally set). +set +x +set -o errexit +bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/atlas_connect +TEST_ATLAS=1 bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg diff --git a/.evergreen/scripts/run-aws-ecs-auth-test.sh b/.evergreen/scripts/run-aws-ecs-auth-test.sh new file mode 100644 index 000000000..787e0a710 --- /dev/null +++ b/.evergreen/scripts/run-aws-ecs-auth-test.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# shellcheck disable=SC2154 +if [ "${skip_ECS_auth_test}" = "true" ]; then + echo "This platform does not support the ECS auth test, skipping..." + exit 0 +fi +set -ex +cd "$DRIVERS_TOOLS"/.evergreen/auth_aws +. ./activate-authawsvenv.sh +. aws_setup.sh ecs +export MONGODB_BINARIES="$MONGODB_BINARIES" +export PROJECT_DIRECTORY="$PROJECT_DIRECTORY" +python aws_tester.py ecs +cd - diff --git a/.evergreen/scripts/run-doctests.sh b/.evergreen/scripts/run-doctests.sh new file mode 100644 index 000000000..f7215ad34 --- /dev/null +++ b/.evergreen/scripts/run-doctests.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +set -o xtrace +PYTHON_BINARY=${PYTHON_BINARY} bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh doctest:test diff --git a/.evergreen/scripts/run-enterprise-auth-tests.sh b/.evergreen/scripts/run-enterprise-auth-tests.sh new file mode 100644 index 000000000..31371ead4 --- /dev/null +++ b/.evergreen/scripts/run-enterprise-auth-tests.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# Disable xtrace for security reasons (just in case it was accidentally set). +set +x +bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/enterprise_auth +TEST_ENTERPRISE_AUTH=1 AUTH=auth bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg diff --git a/.evergreen/scripts/run-gcpkms-fail-test.sh b/.evergreen/scripts/run-gcpkms-fail-test.sh new file mode 100644 index 000000000..dd9d522c8 --- /dev/null +++ b/.evergreen/scripts/run-gcpkms-fail-test.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +. .evergreen/scripts/env.sh +export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 +export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz +SKIP_SERVERS=1 bash ./.evergreen/setup-encryption.sh +SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/hatch.sh test:test-eg diff --git a/.evergreen/scripts/run-getdata.sh b/.evergreen/scripts/run-getdata.sh new file mode 100644 index 000000000..b2d6ecb47 --- /dev/null +++ b/.evergreen/scripts/run-getdata.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -o xtrace +. ${DRIVERS_TOOLS}/.evergreen/download-mongodb.sh || true +get_distro || true +echo $DISTRO +echo $MARCH +echo $OS +uname -a || true +ls /etc/*release* || true +cc --version || true +gcc --version || true +clang --version || true +gcov --version || true +lcov --version || true +llvm-cov --version || true +echo $PATH +ls -la /usr/local/Cellar/llvm/*/bin/ || true +ls -la /usr/local/Cellar/ || true +scan-build --version || true +genhtml --version || true +valgrind --version || true diff --git a/.evergreen/scripts/run-load-balancer.sh b/.evergreen/scripts/run-load-balancer.sh new file mode 100644 index 000000000..7d431777e --- /dev/null +++ b/.evergreen/scripts/run-load-balancer.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +MONGODB_URI=${MONGODB_URI} bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh start diff --git a/.evergreen/scripts/run-mockupdb-tests.sh b/.evergreen/scripts/run-mockupdb-tests.sh new file mode 100644 index 000000000..8825a0237 --- /dev/null +++ b/.evergreen/scripts/run-mockupdb-tests.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +set -o xtrace +export PYTHON_BINARY=${PYTHON_BINARY} +bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-mockupdb diff --git a/.evergreen/run-mod-wsgi-tests.sh b/.evergreen/scripts/run-mod-wsgi-tests.sh similarity index 97% rename from .evergreen/run-mod-wsgi-tests.sh rename to .evergreen/scripts/run-mod-wsgi-tests.sh index e1f523811..607458b8c 100644 --- a/.evergreen/run-mod-wsgi-tests.sh +++ b/.evergreen/scripts/run-mod-wsgi-tests.sh @@ -28,7 +28,7 @@ export MOD_WSGI_SO=/opt/python/mod_wsgi/python_version/$PYTHON_VERSION/mod_wsgi_ export PYTHONHOME=/opt/python/$PYTHON_VERSION # If MOD_WSGI_EMBEDDED is set use the default embedded mode behavior instead # of daemon mode (WSGIDaemonProcess). -if [ -n "$MOD_WSGI_EMBEDDED" ]; then +if [ -n "${MOD_WSGI_EMBEDDED:-}" ]; then export MOD_WSGI_CONF=mod_wsgi_test_embedded.conf else export MOD_WSGI_CONF=mod_wsgi_test.conf diff --git a/.evergreen/run-mongodb-aws-test.sh b/.evergreen/scripts/run-mongodb-aws-test.sh similarity index 67% rename from .evergreen/run-mongodb-aws-test.sh rename to .evergreen/scripts/run-mongodb-aws-test.sh index c4051bb34..ec20bfd06 100755 --- a/.evergreen/run-mongodb-aws-test.sh +++ b/.evergreen/scripts/run-mongodb-aws-test.sh @@ -13,10 +13,16 @@ set -o errexit # Exit the script with error if any of the commands fail # mechanism. # PYTHON_BINARY The Python version to use. -echo "Running MONGODB-AWS authentication tests" +# shellcheck disable=SC2154 +if [ "${skip_EC2_auth_test:-}" = "true" ] && { [ "$1" = "ec2" ] || [ "$1" = "web-identity" ]; }; then + echo "This platform does not support the EC2 auth test, skipping..." + exit 0 +fi + +echo "Running MONGODB-AWS authentication tests for $1" # Handle credentials and environment setup. -. $DRIVERS_TOOLS/.evergreen/auth_aws/aws_setup.sh $1 +. "$DRIVERS_TOOLS"/.evergreen/auth_aws/aws_setup.sh "$1" # show test output set -x diff --git a/.evergreen/scripts/run-ocsp-test.sh b/.evergreen/scripts/run-ocsp-test.sh new file mode 100644 index 000000000..3c6d3b2b3 --- /dev/null +++ b/.evergreen/scripts/run-ocsp-test.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +TEST_OCSP=1 \ +PYTHON_BINARY="${PYTHON_BINARY}" \ +CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem" \ +OCSP_TLS_SHOULD_SUCCEED="${OCSP_TLS_SHOULD_SUCCEED}" \ +bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg +bash "${DRIVERS_TOOLS}"/.evergreen/ocsp/teardown.sh diff --git a/.evergreen/scripts/run-perf-tests.sh b/.evergreen/scripts/run-perf-tests.sh new file mode 100644 index 000000000..69a369fee --- /dev/null +++ b/.evergreen/scripts/run-perf-tests.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +PROJECT_DIRECTORY=${PROJECT_DIRECTORY} +bash "${PROJECT_DIRECTORY}"/.evergreen/run-perf-tests.sh diff --git a/.evergreen/scripts/run-tests.sh b/.evergreen/scripts/run-tests.sh new file mode 100644 index 000000000..495db83e7 --- /dev/null +++ b/.evergreen/scripts/run-tests.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Disable xtrace +set +x +if [ -n "${MONGODB_STARTED}" ]; then + export PYMONGO_MUST_CONNECT=true +fi +if [ -n "${DISABLE_TEST_COMMANDS}" ]; then + export PYMONGO_DISABLE_TEST_COMMANDS=1 +fi +if [ -n "${test_encryption}" ]; then + # Disable xtrace (just in case it was accidentally set). + set +x + bash "${DRIVERS_TOOLS}"/.evergreen/csfle/await-servers.sh + export TEST_ENCRYPTION=1 + if [ -n "${test_encryption_pyopenssl}" ]; then + export TEST_ENCRYPTION_PYOPENSSL=1 + fi +fi +if [ -n "${test_crypt_shared}" ]; then + export TEST_CRYPT_SHARED=1 + export CRYPT_SHARED_LIB_PATH=${CRYPT_SHARED_LIB_PATH} +fi +if [ -n "${test_pyopenssl}" ]; then + export TEST_PYOPENSSL=1 +fi +if [ -n "${SETDEFAULTENCODING}" ]; then + export SETDEFAULTENCODING="${SETDEFAULTENCODING}" +fi +if [ -n "${test_loadbalancer}" ]; then + export TEST_LOADBALANCER=1 + export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}" + export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}" +fi +if [ -n "${test_serverless}" ]; then + export TEST_SERVERLESS=1 +fi +if [ -n "${TEST_INDEX_MANAGEMENT:-}" ]; then + export TEST_INDEX_MANAGEMENT=1 +fi +if [ -n "${SKIP_CSOT_TESTS}" ]; then + export SKIP_CSOT_TESTS=1 +fi +GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \ + PYTHON_BINARY=${PYTHON_BINARY} \ + NO_EXT=${NO_EXT} \ + COVERAGE=${COVERAGE} \ + COMPRESSORS=${COMPRESSORS} \ + AUTH=${AUTH} \ + SSL=${SSL} \ + TEST_DATA_LAKE=${TEST_DATA_LAKE:-} \ + TEST_SUITES=${TEST_SUITES:-} \ + MONGODB_API_VERSION=${MONGODB_API_VERSION} \ + SKIP_HATCH=${SKIP_HATCH} \ + bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg diff --git a/.evergreen/scripts/run-with-env.sh b/.evergreen/scripts/run-with-env.sh new file mode 100644 index 000000000..2fd073605 --- /dev/null +++ b/.evergreen/scripts/run-with-env.sh @@ -0,0 +1,21 @@ +#!/bin/bash -eu + +# Example use: bash run-with-env.sh run-tests.sh {args...} + +# Parameter expansion to get just the current directory's name +if [ "${PWD##*/}" == "src" ]; then + . .evergreen/scripts/env.sh + if [ -f ".evergreen/scripts/test-env.sh" ]; then + . .evergreen/scripts/test-env.sh + fi +else + . src/.evergreen/scripts/env.sh + if [ -f "src/.evergreen/scripts/test-env.sh" ]; then + . src/.evergreen/scripts/test-env.sh + fi +fi + +set -eu + +# shellcheck source=/dev/null +. "$@" diff --git a/.evergreen/scripts/setup-encryption.sh b/.evergreen/scripts/setup-encryption.sh new file mode 100644 index 000000000..2f167cd20 --- /dev/null +++ b/.evergreen/scripts/setup-encryption.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +if [ -n "${test_encryption}" ]; then + ./.evergreen/hatch.sh encryption:setup +fi diff --git a/.evergreen/scripts/setup-tests.sh b/.evergreen/scripts/setup-tests.sh new file mode 100644 index 000000000..65462b2a6 --- /dev/null +++ b/.evergreen/scripts/setup-tests.sh @@ -0,0 +1,27 @@ +#!/bin/bash -eux + +PROJECT_DIRECTORY="$(pwd)" +SCRIPT_DIR="$PROJECT_DIRECTORY/.evergreen/scripts" + +if [ -f "$SCRIPT_DIR/test-env.sh" ]; then + echo "Reading $SCRIPT_DIR/test-env.sh file" + . "$SCRIPT_DIR/test-env.sh" + exit 0 +fi + +cat < "$SCRIPT_DIR"/test-env.sh +export test_encryption="${test_encryption:-}" +export test_encryption_pyopenssl="${test_encryption_pyopenssl:-}" +export test_crypt_shared="${test_crypt_shared:-}" +export test_pyopenssl="${test_pyopenssl:-}" +export test_loadbalancer="${test_loadbalancer:-}" +export test_serverless="${test_serverless:-}" +export TEST_INDEX_MANAGEMENT="${TEST_INDEX_MANAGEMENT:-}" +export TEST_DATA_LAKE="${TEST_DATA_LAKE:-}" +export ORCHESTRATION_FILE="${ORCHESTRATION_FILE:-}" +export AUTH="${AUTH:-noauth}" +export SSL="${SSL:-nossl}" +export PYTHON_BINARY="${PYTHON_BINARY:-}" +EOT + +chmod +x "$SCRIPT_DIR"/test-env.sh diff --git a/.evergreen/scripts/stop-load-balancer.sh b/.evergreen/scripts/stop-load-balancer.sh new file mode 100644 index 000000000..2d3c5366e --- /dev/null +++ b/.evergreen/scripts/stop-load-balancer.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd "${DRIVERS_TOOLS}"/.evergreen || exit +DRIVERS_TOOLS=${DRIVERS_TOOLS} +bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh stop diff --git a/.evergreen/scripts/teardown-aws.sh b/.evergreen/scripts/teardown-aws.sh new file mode 100644 index 000000000..634d1e572 --- /dev/null +++ b/.evergreen/scripts/teardown-aws.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +cd "${DRIVERS_TOOLS}/.evergreen/auth_aws" || exit +if [ -f "./aws_e2e_setup.json" ]; then + . ./activate-authawsvenv.sh + python ./lib/aws_assign_instance_profile.py +fi diff --git a/.evergreen/scripts/teardown-docker.sh b/.evergreen/scripts/teardown-docker.sh new file mode 100644 index 000000000..733779d05 --- /dev/null +++ b/.evergreen/scripts/teardown-docker.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +# Remove all Docker images +DOCKER=$(command -v docker) || true +if [ -n "$DOCKER" ]; then + docker rmi -f "$(docker images -a -q)" &> /dev/null || true +fi diff --git a/.evergreen/scripts/upload-coverage-report.sh b/.evergreen/scripts/upload-coverage-report.sh new file mode 100644 index 000000000..71a2a80bb --- /dev/null +++ b/.evergreen/scripts/upload-coverage-report.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +aws s3 cp htmlcov/ s3://"$1"/coverage/"$2"/"$3"/htmlcov/ --recursive --acl public-read --region us-east-1 diff --git a/.evergreen/scripts/windows-fix.sh b/.evergreen/scripts/windows-fix.sh new file mode 100644 index 000000000..cb4fa4413 --- /dev/null +++ b/.evergreen/scripts/windows-fix.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +set +x +. src/.evergreen/scripts/env.sh +# shellcheck disable=SC2044 +for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do + < "$i" tr -d '\r' >"$i".new + mv "$i".new "$i" +done +# Copy client certificate because symlinks do not work on Windows. +cp "$DRIVERS_TOOLS"/.evergreen/x509gen/client.pem "$MONGO_ORCHESTRATION_HOME"/lib/client.pem diff --git a/.evergreen/setup-encryption.sh b/.evergreen/setup-encryption.sh index 71231e173..b403ef9ca 100644 --- a/.evergreen/setup-encryption.sh +++ b/.evergreen/setup-encryption.sh @@ -52,6 +52,9 @@ ls -la libmongocrypt ls -la libmongocrypt/nocrypto if [ -z "${SKIP_SERVERS:-}" ]; then - bash ${DRIVERS_TOOLS}/.evergreen/csfle/setup-secrets.sh - bash ${DRIVERS_TOOLS}/.evergreen/csfle/start-servers.sh + PYTHON_BINARY_OLD=${PYTHON_BINARY} + export PYTHON_BINARY="" + bash "${DRIVERS_TOOLS}"/.evergreen/csfle/setup-secrets.sh + export PYTHON_BINARY=$PYTHON_BINARY_OLD + bash "${DRIVERS_TOOLS}"/.evergreen/csfle/start-servers.sh fi diff --git a/.evergreen/utils.sh b/.evergreen/utils.sh index d44425a90..908cf0564 100755 --- a/.evergreen/utils.sh +++ b/.evergreen/utils.sh @@ -17,7 +17,7 @@ find_python3() { elif [ -d "/Library/Frameworks/Python.Framework/Versions/3.9" ]; then PYTHON="/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3" fi - elif [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin + elif [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin PYTHON="C:/python/Python39/python.exe" else # Prefer our own toolchain, fall back to mongodb toolchain if it has Python 3.9+. @@ -56,7 +56,7 @@ createvirtualenv () { # Workaround for bug in older versions of virtualenv. $VIRTUALENV $VENVPATH 2>/dev/null || $VIRTUALENV $VENVPATH fi - if [ "Windows_NT" = "$OS" ]; then + if [ "Windows_NT" = "${OS:-}" ]; then # Workaround https://bugs.python.org/issue32451: # mongovenv/Scripts/activate: line 3: $'\r': command not found dos2unix $VENVPATH/Scripts/activate || true diff --git a/test/asynchronous/test_client_context.py b/test/asynchronous/test_client_context.py index 6d7781843..6a195eb6b 100644 --- a/test/asynchronous/test_client_context.py +++ b/test/asynchronous/test_client_context.py @@ -25,7 +25,7 @@ _IS_SYNC = False class TestAsyncClientContext(AsyncUnitTest): def test_must_connect(self): - if "PYMONGO_MUST_CONNECT" not in os.environ: + if not os.environ.get("PYMONGO_MUST_CONNECT"): raise SkipTest("PYMONGO_MUST_CONNECT is not set") self.assertTrue( @@ -37,7 +37,7 @@ class TestAsyncClientContext(AsyncUnitTest): ) def test_serverless(self): - if "TEST_SERVERLESS" not in os.environ: + if not os.environ.get("TEST_SERVERLESS"): raise SkipTest("TEST_SERVERLESS is not set") self.assertTrue( @@ -47,7 +47,7 @@ class TestAsyncClientContext(AsyncUnitTest): ) def test_enableTestCommands_is_disabled(self): - if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ: + if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"): raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set") self.assertFalse( @@ -56,7 +56,7 @@ class TestAsyncClientContext(AsyncUnitTest): ) def test_setdefaultencoding_worked(self): - if "SETDEFAULTENCODING" not in os.environ: + if not os.environ.get("SETDEFAULTENCODING"): raise SkipTest("SETDEFAULTENCODING is not set") self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) diff --git a/test/mod_wsgi_test/README.rst b/test/mod_wsgi_test/README.rst index 2c204f7ac..e96db9406 100644 --- a/test/mod_wsgi_test/README.rst +++ b/test/mod_wsgi_test/README.rst @@ -107,4 +107,4 @@ Automation At MongoDB, Inc. we use a continuous integration job that tests each combination in the matrix. The job starts up Apache, starts a single server or replica set, and runs ``test_client.py`` with the proper arguments. -See `run-mod-wsgi-tests.sh `_ +See `run-mod-wsgi-tests.sh `_ diff --git a/test/test_client_context.py b/test/test_client_context.py index 5996f9243..e807ac5f5 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -25,7 +25,7 @@ _IS_SYNC = True class TestClientContext(UnitTest): def test_must_connect(self): - if "PYMONGO_MUST_CONNECT" not in os.environ: + if not os.environ.get("PYMONGO_MUST_CONNECT"): raise SkipTest("PYMONGO_MUST_CONNECT is not set") self.assertTrue( @@ -37,7 +37,7 @@ class TestClientContext(UnitTest): ) def test_serverless(self): - if "TEST_SERVERLESS" not in os.environ: + if not os.environ.get("TEST_SERVERLESS"): raise SkipTest("TEST_SERVERLESS is not set") self.assertTrue( @@ -47,7 +47,7 @@ class TestClientContext(UnitTest): ) def test_enableTestCommands_is_disabled(self): - if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ: + if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"): raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set") self.assertFalse( @@ -56,7 +56,7 @@ class TestClientContext(UnitTest): ) def test_setdefaultencoding_worked(self): - if "SETDEFAULTENCODING" not in os.environ: + if not os.environ.get("SETDEFAULTENCODING"): raise SkipTest("SETDEFAULTENCODING is not set") self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) From 9b5c0981d91aa1baf23d6300bc033fd832457ae4 Mon Sep 17 00:00:00 2001 From: Jib Date: Mon, 25 Nov 2024 13:13:44 -0500 Subject: [PATCH 4/9] PYTHON-4988: Check C extensions are loaded ONLY in CPython builds (#2016) --- .evergreen/run-tests.sh | 4 ++-- .evergreen/utils.sh | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 9716c1fc7..95fe10a6c 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -38,6 +38,7 @@ export PIP_PREFER_BINARY=1 # Prefer binary dists by default set +x python -c "import sys; sys.exit(sys.prefix == sys.base_prefix)" || (echo "Not inside a virtual env!"; exit 1) +PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())") # Try to source local Drivers Secrets if [ -f ./secrets-export.sh ]; then @@ -48,7 +49,7 @@ else fi # Ensure C extensions have compiled. -if [ -z "${NO_EXT:-}" ]; then +if [ -z "${NO_EXT:-}" ] && [ "$PYTHON_IMPL" = "CPython" ]; then python tools/fail_if_no_c.py fi @@ -245,7 +246,6 @@ python -c 'import sys; print(sys.version)' # Run the tests with coverage if requested and coverage is installed. # Only cover CPython. PyPy reports suspiciously low coverage. -PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())") if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then # Keep in sync with combine-coverage.sh. # coverage >=5 is needed for relative_files=true. diff --git a/.evergreen/utils.sh b/.evergreen/utils.sh index 908cf0564..d3af2dcc7 100755 --- a/.evergreen/utils.sh +++ b/.evergreen/utils.sh @@ -78,6 +78,7 @@ testinstall () { PYTHON=$1 RELEASE=$2 NO_VIRTUALENV=$3 + PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())") if [ -z "$NO_VIRTUALENV" ]; then createvirtualenv $PYTHON venvtestinstall @@ -86,7 +87,11 @@ testinstall () { $PYTHON -m pip install --upgrade $RELEASE cd tools - $PYTHON fail_if_no_c.py + + if [ "$PYTHON_IMPL" = "CPython" ]; then + $PYTHON fail_if_no_c.py + fi + $PYTHON -m pip uninstall -y pymongo cd .. From 0e8d70457f2e1ac05baff0c9f5232ddee2b0abcf Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 26 Nov 2024 16:55:27 -0500 Subject: [PATCH 5/9] Async client uses tasks instead of threads PYTHON-4725 - Async client should use tasks for SDAM instead of threads PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition PYTHON-4941 - Synchronous unified test runner being used in asynchronous tests PYTHON-4843 - Async test suite should use a single event loop PYTHON-4945 - Fix test cleanups for mongoses Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com> --- .evergreen/config.yml | 2 +- THIRD-PARTY-NOTICES | 58 ++ pymongo/_asyncio_lock.py | 309 +++++++ pymongo/_asyncio_task.py | 49 + pymongo/asynchronous/client_bulk.py | 1 - pymongo/asynchronous/cursor.py | 4 +- pymongo/asynchronous/encryption.py | 7 + pymongo/asynchronous/mongo_client.py | 25 +- pymongo/asynchronous/monitor.py | 54 +- pymongo/asynchronous/periodic_executor.py | 219 ----- pymongo/asynchronous/pool.py | 28 +- pymongo/asynchronous/topology.py | 24 +- pymongo/lock.py | 245 +---- pymongo/network_layer.py | 10 +- .../{synchronous => }/periodic_executor.py | 113 ++- pymongo/synchronous/client_bulk.py | 1 - pymongo/synchronous/cursor.py | 2 +- pymongo/synchronous/encryption.py | 7 + pymongo/synchronous/mongo_client.py | 21 +- pymongo/synchronous/monitor.py | 18 +- pymongo/synchronous/pool.py | 28 +- pymongo/synchronous/topology.py | 20 +- test/__init__.py | 97 +- test/asynchronous/__init__.py | 101 +-- test/asynchronous/conftest.py | 2 +- test/asynchronous/test_bulk.py | 36 +- test/asynchronous/test_change_stream.py | 43 +- test/asynchronous/test_client.py | 80 +- test/asynchronous/test_collation.py | 31 +- test/asynchronous/test_collection.py | 41 +- ...nnections_survive_primary_stepdown_spec.py | 31 +- test/asynchronous/test_create_entities.py | 6 + test/asynchronous/test_cursor.py | 14 +- test/asynchronous/test_database.py | 3 +- test/asynchronous/test_encryption.py | 140 ++- test/asynchronous/test_grid_file.py | 1 + test/asynchronous/test_locks.py | 857 ++++++++---------- test/asynchronous/test_monitoring.py | 41 +- test/asynchronous/test_retryable_writes.py | 65 +- test/asynchronous/test_session.py | 33 +- test/asynchronous/test_transactions.py | 23 +- test/asynchronous/unified_format.py | 64 +- test/asynchronous/utils_spec_runner.py | 26 +- test/conftest.py | 2 +- test/test_bulk.py | 32 +- test/test_change_stream.py | 39 +- test/test_client.py | 31 +- test/test_collation.py | 29 +- test/test_collection.py | 38 +- ...nnections_survive_primary_stepdown_spec.py | 31 +- test/test_create_entities.py | 6 + test/test_cursor.py | 4 - test/test_custom_types.py | 23 +- test/test_database.py | 1 + test/test_encryption.py | 138 ++- test/test_examples.py | 13 +- test/test_grid_file.py | 1 + test/test_gridfs.py | 20 +- test/test_gridfs_bucket.py | 14 +- test/test_monitor.py | 2 +- test/test_monitoring.py | 39 +- test/test_read_concern.py | 20 +- test/test_retryable_writes.py | 65 +- test/test_sdam_monitoring_spec.py | 2 +- test/test_session.py | 32 +- test/test_threads.py | 1 + test/test_transactions.py | 15 +- test/test_typing.py | 7 +- test/unified_format.py | 57 +- test/utils.py | 11 +- test/utils_spec_runner.py | 26 +- tools/synchro.py | 39 +- 72 files changed, 1737 insertions(+), 1981 deletions(-) create mode 100644 pymongo/_asyncio_lock.py create mode 100644 pymongo/_asyncio_task.py delete mode 100644 pymongo/asynchronous/periodic_executor.py rename pymongo/{synchronous => }/periodic_executor.py (67%) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 59b8a543f..7ca3a72b1 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -281,7 +281,7 @@ functions: "run tests": - command: subprocess.exec params: - include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] + include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] binary: bash working_dir: "src" args: diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index 55b8ff707..ad00831a2 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +2) License Notice for _asyncio_lock.py +----------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved" +are retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py new file mode 100644 index 000000000..669b0f63a --- /dev/null +++ b/pymongo/_asyncio_lock.py @@ -0,0 +1,309 @@ +# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved + +"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py +to port 3.13 fixes to older versions of Python. +Can be removed once we drop Python 3.12 support.""" + +from __future__ import annotations + +import collections +import threading +from asyncio import events, exceptions +from typing import Any, Coroutine, Optional + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> Any: + loop = events._get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class _ContextManagerMixin: + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] + # We have no use for the "as ..." clause in the with + # statement for locks. + return + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.release() # type: ignore[attr-defined] + + +class Lock(_ContextManagerMixin, _LoopBoundMixin): + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular task when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another task changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one task is blocked in acquire() waiting for + the state to turn to unlocked, only one task proceeds when a + release() call resets the state to unlocked; successive release() + calls will unblock tasks in FIFO order. + + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. + + Usage: + + lock = Lock() + ... + await lock.acquire() + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + async with lock: + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + await lock.acquire() + else: + # lock is acquired + ... + + """ + + def __init__(self) -> None: + self._waiters: Optional[collections.deque] = None + self._locked = False + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self) -> bool: + """Return True if lock is acquired.""" + return self._locked + + async def acquire(self) -> bool: + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + # Implement fair scheduling, where thread always waits + # its turn. Jumping the queue if all are cancelled is an optimization. + if not self._locked and ( + self._waiters is None or all(w.cancelled() for w in self._waiters) + ): + self._locked = True + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + # Currently the only exception designed be able to occur here. + + # Ensure the lock invariant: If lock is not claimed (or about + # to be claimed by us) and there is a Task in waiters, + # ensure that the Task at the head will run. + if not self._locked: + self._wake_up_first() + raise + + # assert self._locked is False + self._locked = True + return True + + def release(self) -> None: + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other tasks are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock is not acquired.") + + def _wake_up_first(self) -> None: + """Ensure that the first waiter will wake up.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() means that the waiter is already set to wake up. + if not fut.done(): + fut.set_result(True) + + +class Condition(_ContextManagerMixin, _LoopBoundMixin): + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more tasks to wait until they are notified by another + task. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock: Optional[Lock] = None) -> None: + if lock is None: + lock = Lock() + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters: collections.deque = collections.deque() + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self.locked() else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + async def wait(self) -> bool: + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + if not self.locked(): + raise RuntimeError("cannot wait on un-acquired lock") + + fut = self._get_loop().create_future() + self.release() + try: + try: + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except exceptions.CancelledError as e: + err = e + + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self._notify(1) + raise + + async def wait_for(self, predicate: Any) -> Coroutine: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + + def notify(self, n: int = 1) -> None: + """By default, wake up one task waiting on this condition, if any. + If the calling task has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up n of the tasks waiting for the condition + variable; if fewer than n are waiting, they are all awoken. + + Note: an awakened task does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError("cannot notify on un-acquired lock") + self._notify(n) + + def _notify(self, n: int) -> None: + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self) -> None: + """Wake up all tasks waiting on this condition. This method acts + like notify(), but wakes up all waiting tasks instead of one. If the + calling task has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) diff --git a/pymongo/_asyncio_task.py b/pymongo/_asyncio_task.py new file mode 100644 index 000000000..8e457763d --- /dev/null +++ b/pymongo/_asyncio_task.py @@ -0,0 +1,49 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request. +Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling.""" + + +from __future__ import annotations + +import asyncio +import sys +from typing import Any, Coroutine, Optional + + +# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered +class _Task(asyncio.Task): + def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None: + super().__init__(coro, name=name) + self._cancel_requests = 0 + asyncio._register_task(self) + + def cancel(self, msg: Optional[str] = None) -> bool: + self._cancel_requests += 1 + return super().cancel(msg=msg) + + def uncancel(self) -> int: + if self._cancel_requests > 0: + self._cancel_requests -= 1 + return self._cancel_requests + + def cancelling(self) -> int: + return self._cancel_requests + + +def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task: + if sys.version_info >= (3, 11): + return asyncio.create_task(coro, name=name) + return _Task(coro, name=name) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 0dcdaa6c0..45824256d 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -476,7 +476,6 @@ class _AsyncClientBulk: if op_type == "delete": res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] full_result[f"{op_type}Results"][original_index] = res - except Exception as exc: # Attempt to close the cursor, then raise top-level error. if cmd_cursor.alive: diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 4b4bb52a8..7d7ae4a5d 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -45,7 +45,7 @@ from pymongo.common import ( ) from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _ALock, _create_lock +from pymongo.lock import _async_create_lock from pymongo.message import ( _CursorAddress, _GetMore, @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: AsyncConnection, more_to_come: bool): self.conn: Optional[AsyncConnection] = conn self.more_to_come = more_to_come - self._alock = _ALock(_create_lock()) + self._lock = _async_create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 735e54304..4802c3f54 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -15,6 +15,7 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations +import asyncio import contextlib import enum import socket @@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptionError(exc) from exc @@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. + except asyncio.CancelledError: + raise except Exception as error: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) @@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]): await database.create_collection(name=name, **kwargs), encrypted_fields, ) + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 3e4dc482d..1600e5062 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -32,6 +32,7 @@ access: """ from __future__ import annotations +import asyncio import contextlib import os import warnings @@ -59,8 +60,8 @@ from typing import ( from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, uri_parser -from pymongo.asynchronous import client_session, database, periodic_executor +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser +from pymongo.asynchronous import client_session, database from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession @@ -82,7 +83,11 @@ from pymongo.errors import ( WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _async_create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase - self._lock = _ALock(_create_lock()) + self._lock = _async_create_lock() self._kill_cursors_queue: list = [] self._event_listeners = options.pool_options._event_listeners @@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): await AsyncMongoClient._process_periodic_tasks(client) return True - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=common.KILL_CURSOR_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, @@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): address=address, ) - async with operation.conn_mgr._alock: + async with operation.conn_mgr._lock: async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return await server.run_operation( @@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): try: if conn_mgr: - async with conn_mgr._alock: + async with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None @@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_id, conn_mgr in pinned_cursors: try: await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_ids in address_to_cursor_ids.items(): try: await self._kill_cursors(cursor_ids, address, topology, session=None) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): try: await self._process_kill_cursors() await self._topology.update_pool() + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index a4dc9b7f4..ad1bc70ab 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -16,20 +16,20 @@ from __future__ import annotations +import asyncio import atexit import logging import time import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast -from pymongo import common +from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.asynchronous import periodic_executor -from pymongo.asynchronous.periodic_executor import _shutdown_executors from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _async_create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage +from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription @@ -76,7 +76,7 @@ class MonitorBase: await monitor._run() # type:ignore[attr-defined] return True - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=interval, min_interval=min_interval, target=target, name=name ) @@ -112,9 +112,9 @@ class MonitorBase: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + async def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + await self._executor.join(timeout) def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -139,7 +139,7 @@ class Monitor(MonitorBase): """ super().__init__( topology, - "pymongo_server_monitor_thread", + "pymongo_server_monitor_task", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL, ) @@ -238,6 +238,9 @@ class Monitor(MonitorBase): except ReferenceError: # Topology was garbage-collected. await self.close() + finally: + if self._executor._stopped: + await self._rtt_monitor.close() async def _check_server(self) -> ServerDescription: """Call hello or read the next streaming response. @@ -252,8 +255,10 @@ class Monitor(MonitorBase): except (OperationFailure, NotPrimaryError) as exc: # Update max cluster time even when hello fails. details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) + await self._topology.receive_cluster_time(details.get("$clusterTime")) raise + except asyncio.CancelledError: + raise except ReferenceError: raise except Exception as error: @@ -280,7 +285,7 @@ class Monitor(MonitorBase): await self._reset_connection() if isinstance(error, _OperationCancelled): raise - self._rtt_monitor.reset() + await self._rtt_monitor.reset() # Server type defaults to Unknown. return ServerDescription(address, error=error) @@ -321,9 +326,9 @@ class Monitor(MonitorBase): self._conn_id = conn.id response, round_trip_time = await self._check_with_socket(conn) if not response.awaitable: - self._rtt_monitor.add_sample(round_trip_time) + await self._rtt_monitor.add_sample(round_trip_time) - avg_rtt, min_rtt = self._rtt_monitor.get() + avg_rtt, min_rtt = await self._rtt_monitor.get() sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt) if self._publish: assert self._listeners is not None @@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase): if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception + except asyncio.CancelledError: + raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase): """ super().__init__( topology, - "pymongo_server_rtt_thread", + "pymongo_server_rtt_task", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL, ) @@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase): self._pool = pool self._moving_average = MovingAverage() self._moving_min = MovingMinimum() - self._lock = _create_lock() + self._lock = _async_create_lock() async def close(self) -> None: self.gc_safe_close() @@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase): # thread has the socket checked out, it will be closed when checked in. await self._pool.reset() - def add_sample(self, sample: float) -> None: + async def add_sample(self, sample: float) -> None: """Add a RTT sample.""" - with self._lock: + async with self._lock: self._moving_average.add_sample(sample) self._moving_min.add_sample(sample) - def get(self) -> tuple[Optional[float], float]: + async def get(self) -> tuple[Optional[float], float]: """Get the calculated average, or None if no samples yet and the min.""" - with self._lock: + async with self._lock: return self._moving_average.get(), self._moving_min.get() - def reset(self) -> None: + async def reset(self) -> None: """Reset the average RTT.""" - with self._lock: + async with self._lock: self._moving_average.reset() self._moving_min.reset() @@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase): # heartbeat protocol (MongoDB 4.4+). # XXX: Skip check if the server is unknown? rtt = await self._ping() - self.add_sample(rtt) + await self.add_sample(rtt) except ReferenceError: # Topology was garbage-collected. await self.close() + except asyncio.CancelledError: + raise except Exception: await self._pool.reset() @@ -536,4 +545,5 @@ def _shutdown_resources() -> None: shutdown() -atexit.register(_shutdown_resources) +if _IS_SYNC: + atexit.register(_shutdown_resources) diff --git a/pymongo/asynchronous/periodic_executor.py b/pymongo/asynchronous/periodic_executor.py deleted file mode 100644 index f3d2fddba..000000000 --- a/pymongo/asynchronous/periodic_executor.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Run a target function on a background thread.""" - -from __future__ import annotations - -import asyncio -import sys -import threading -import time -import weakref -from typing import Any, Optional - -from pymongo.lock import _ALock, _create_lock - -_IS_SYNC = False - - -class PeriodicExecutor: - def __init__( - self, - interval: float, - min_interval: float, - target: Any, - name: Optional[str] = None, - ): - """Run a target function periodically on a background thread. - - If the target's return value is false, the executor stops. - - :param interval: Seconds between calls to `target`. - :param min_interval: Minimum seconds between calls if `wake` is - called very often. - :param target: A function. - :param name: A name to give the underlying thread. - """ - # threading.Event and its internal condition variable are expensive - # in Python 2, see PYTHON-983. Use a boolean to know when to wake. - # The executor's design is constrained by several Python issues, see - # "periodic_executor.rst" in this repository. - self._event = False - self._interval = interval - self._min_interval = min_interval - self._target = target - self._stopped = False - self._thread: Optional[threading.Thread] = None - self._name = name - self._skip_sleep = False - self._thread_will_exit = False - self._lock = _ALock(_create_lock()) - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" - - def _run_async(self) -> None: - # The default asyncio loop implementation on Windows - # has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240) - # We explicitly use a different loop implementation here to prevent that issue - if sys.platform == "win32": - loop = asyncio.SelectorEventLoop() - try: - loop.run_until_complete(self._run()) # type: ignore[func-returns-value] - finally: - loop.close() - else: - asyncio.run(self._run()) # type: ignore[func-returns-value] - - def open(self) -> None: - """Start. Multiple calls have no effect. - - Not safe to call from multiple threads at once. - """ - with self._lock: - if self._thread_will_exit: - # If the background thread has read self._stopped as True - # there is a chance that it has not yet exited. The call to - # join should not block indefinitely because there is no - # other work done outside the while loop in self._run. - try: - assert self._thread is not None - self._thread.join() - except ReferenceError: - # Thread terminated. - pass - self._thread_will_exit = False - self._stopped = False - started: Any = False - try: - started = self._thread and self._thread.is_alive() - except ReferenceError: - # Thread terminated. - pass - - if not started: - if _IS_SYNC: - thread = threading.Thread(target=self._run, name=self._name) - else: - thread = threading.Thread(target=self._run_async, name=self._name) - thread.daemon = True - self._thread = weakref.proxy(thread) - _register_executor(self) - # Mitigation to RuntimeError firing when thread starts on shutdown - # https://github.com/python/cpython/issues/114570 - try: - thread.start() - except RuntimeError as e: - if "interpreter shutdown" in str(e) or sys.is_finalizing(): - self._thread = None - return - raise - - def close(self, dummy: Any = None) -> None: - """Stop. To restart, call open(). - - The dummy parameter allows an executor's close method to be a weakref - callback; see monitor.py. - """ - self._stopped = True - - def join(self, timeout: Optional[int] = None) -> None: - if self._thread is not None: - try: - self._thread.join(timeout) - except (ReferenceError, RuntimeError): - # Thread already terminated, or not yet started. - pass - - def wake(self) -> None: - """Execute the target function soon.""" - self._event = True - - def update_interval(self, new_interval: int) -> None: - self._interval = new_interval - - def skip_sleep(self) -> None: - self._skip_sleep = True - - async def _should_stop(self) -> bool: - async with self._lock: - if self._stopped: - self._thread_will_exit = True - return True - return False - - async def _run(self) -> None: - while not await self._should_stop(): - try: - if not await self._target(): - self._stopped = True - break - except BaseException: - async with self._lock: - self._stopped = True - self._thread_will_exit = True - - raise - - if self._skip_sleep: - self._skip_sleep = False - else: - deadline = time.monotonic() + self._interval - while not self._stopped and time.monotonic() < deadline: - await asyncio.sleep(self._min_interval) - if self._event: - break # Early wake. - - self._event = False - - -# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started, -# an executor is kept alive by a strong reference from its thread and perhaps -# from other objects. When the thread dies and all other referrers are freed, -# the executor is freed and removed from _EXECUTORS. If any threads are -# running when the interpreter begins to shut down, we try to halt and join -# them to avoid spurious errors. -_EXECUTORS = set() - - -def _register_executor(executor: PeriodicExecutor) -> None: - ref = weakref.ref(executor, _on_executor_deleted) - _EXECUTORS.add(ref) - - -def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None: - _EXECUTORS.remove(ref) - - -def _shutdown_executors() -> None: - if _EXECUTORS is None: - return - - # Copy the set. Stopping threads has the side effect of removing executors. - executors = list(_EXECUTORS) - - # First signal all executors to close... - for ref in executors: - executor = ref() - if executor: - executor.close() - - # ...then try to join them. - for ref in executors: - executor = ref() - if executor: - executor.join(1) - - executor = None diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index ca0cebd41..5dc5675a0 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -23,7 +23,6 @@ import os import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined] _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +211,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return await condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -706,6 +704,8 @@ class AsyncConnection: # shutdown. try: self.conn.close() + except asyncio.CancelledError: + raise except Exception: # noqa: S110 pass @@ -992,8 +992,8 @@ class Pool: # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _ALock(_lock) + self.lock = _async_create_lock() + self._max_connecting_cond = _async_create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1019,7 +1019,7 @@ class Pool: # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(_lock)) + self.size_cond = _async_create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1027,7 +1027,7 @@ class Pool: # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(_lock)) + self._max_connecting_cond = _async_create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1466,7 +1466,8 @@ class Pool: async with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not await _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1489,7 +1490,8 @@ class Pool: async with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not await _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 82af4257b..6d67710a7 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -27,8 +27,7 @@ import weakref from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, common, helpers_shared -from pymongo.asynchronous import periodic_executor +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.asynchronous.monitor import SrvMonitor from pymongo.asynchronous.pool import Pool @@ -44,7 +43,11 @@ from pymongo.errors import ( WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -170,9 +173,10 @@ class Topology: self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _ALock(_lock) - self._condition = _ACondition(self._settings.condition_class(_lock)) + self._lock = _async_create_lock() + self._condition = _async_create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -185,7 +189,7 @@ class Topology: async def target() -> bool: return process_events_queue(weak) - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=common.EVENTS_QUEUE_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, @@ -354,7 +358,7 @@ class Topology: # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -654,7 +658,7 @@ class Topology: """Wake all monitors, wait for at least one to check its server.""" async with self._lock: self._request_check_all() - await self._condition.wait(wait_time) + await _async_cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. @@ -742,7 +746,7 @@ class Topology: if self._publish_server or self._publish_tp: # Make sure the events executor thread is fully closed before publishing the remaining events self.__events_executor.close() - self.__events_executor.join(1) + await self.__events_executor.join(1) process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type] @property diff --git a/pymongo/lock.py b/pymongo/lock.py index 0cbfb4a57..6bf713801 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -11,15 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Internal helpers for lock and condition coordination primitives.""" + from __future__ import annotations import asyncio -import collections import os +import sys import threading -import time import weakref -from typing import Any, Callable, Optional, TypeVar +from asyncio import wait_for +from typing import Any, Optional, TypeVar + +import pymongo._asyncio_lock _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet() _T = TypeVar("_T") +# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202) +# in older versions of Python +if sys.version_info >= (3, 13): + Lock = asyncio.Lock + Condition = asyncio.Condition +else: + Lock = pymongo._asyncio_lock.Lock + Condition = pymongo._asyncio_lock.Condition + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock: return lock +def _async_create_lock() -> Lock: + """Represents an asyncio.Lock.""" + return Lock() + + +def _create_condition( + lock: threading.Lock, condition_class: Optional[Any] = None +) -> threading.Condition: + """Represents a threading.Condition.""" + if condition_class: + return condition_class(lock) + return threading.Condition(lock) + + +def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition: + """Represents an asyncio.Condition.""" + if condition_class: + return condition_class(lock) + return Condition(lock) + + def _release_locks() -> None: # Completed the fork, reset all the locks in the child. for lock in _forkable_locks: @@ -46,202 +81,12 @@ def _release_locks() -> None: lock.release() -# Needed only for synchro.py compat. -def _Lock(lock: threading.Lock) -> threading.Lock: - return lock +async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool: + try: + return await wait_for(condition.wait(), timeout) + except asyncio.TimeoutError: + return False -class _ALock: - __slots__ = ("_lock",) - - def __init__(self, lock: threading.Lock) -> None: - self._lock = lock - - def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - return self._lock.acquire(blocking=blocking, timeout=timeout) - - async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._lock.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - def release(self) -> None: - self._lock.release() - - async def __aenter__(self) -> _ALock: - await self.a_acquire() - return self - - def __enter__(self) -> _ALock: - self._lock.acquire() - return self - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - -def _safe_set_result(fut: asyncio.Future) -> None: - # Ensure the future hasn't been cancelled before calling set_result. - if not fut.done(): - fut.set_result(False) - - -class _ACondition: - __slots__ = ("_condition", "_waiters") - - def __init__(self, condition: threading.Condition) -> None: - self._condition = condition - self._waiters: collections.deque = collections.deque() - - async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._condition.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - async def wait(self, timeout: Optional[float] = None) -> bool: - """Wait until notified. - - If the calling task has not acquired the lock when this - method is called, a RuntimeError is raised. - - This method releases the underlying lock, and then blocks - until it is awakened by a notify() or notify_all() call for - the same condition variable in another task. Once - awakened, it re-acquires the lock and returns True. - - This method may return spuriously, - which is why the caller should always - re-check the state and be prepared to wait() again. - """ - loop = asyncio.get_running_loop() - fut = loop.create_future() - self._waiters.append((loop, fut)) - self.release() - try: - try: - try: - await asyncio.wait_for(fut, timeout) - return True - except asyncio.TimeoutError: - return False # Return false on timeout for sync pool compat. - finally: - # Must re-acquire lock even if wait is cancelled. - # We only catch CancelledError here, since we don't want any - # other (fatal) errors with the future to cause us to spin. - err = None - while True: - try: - await self.acquire() - break - except asyncio.exceptions.CancelledError as e: - err = e - - self._waiters.remove((loop, fut)) - if err is not None: - try: - raise err # Re-raise most recent exception instance. - finally: - err = None # Break reference cycles. - except BaseException: - # Any error raised out of here _may_ have occurred after this Task - # believed to have been successfully notified. - # Make sure to notify another Task instead. This may result - # in a "spurious wakeup", which is allowed as part of the - # Condition Variable protocol. - self.notify(1) - raise - - async def wait_for(self, predicate: Callable[[], _T]) -> _T: - """Wait until a predicate becomes true. - - The predicate should be a callable whose result will be - interpreted as a boolean value. The method will repeatedly - wait() until it evaluates to true. The final predicate value is - the return value. - """ - result = predicate() - while not result: - await self.wait() - result = predicate() - return result - - def notify(self, n: int = 1) -> None: - """By default, wake up one coroutine waiting on this condition, if any. - If the calling coroutine has not acquired the lock when this method - is called, a RuntimeError is raised. - - This method wakes up at most n of the coroutines waiting for the - condition variable; it is a no-op if no coroutines are waiting. - - Note: an awakened coroutine does not actually return from its - wait() call until it can reacquire the lock. Since notify() does - not release the lock, its caller should. - """ - idx = 0 - to_remove = [] - for loop, fut in self._waiters: - if idx >= n: - break - - if fut.done(): - continue - - try: - loop.call_soon_threadsafe(_safe_set_result, fut) - except RuntimeError: - # Loop was closed, ignore. - to_remove.append((loop, fut)) - continue - - idx += 1 - - for waiter in to_remove: - self._waiters.remove(waiter) - - def notify_all(self) -> None: - """Wake up all threads waiting on this condition. This method acts - like notify(), but wakes up all waiting threads instead of one. If the - calling thread has not acquired the lock when this method is called, - a RuntimeError is raised. - """ - self.notify(len(self._waiters)) - - def locked(self) -> bool: - """Only needed for tests in test_locks.""" - return self._condition._lock.locked() # type: ignore[attr-defined] - - def release(self) -> None: - self._condition.release() - - async def __aenter__(self) -> _ACondition: - await self.acquire() - return self - - def __enter__(self) -> _ACondition: - self._condition.acquire() - return self - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() +def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool: + return condition.wait(timeout) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index aa16e85a0..6ab6db2f7 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -29,6 +29,7 @@ from typing import ( ) from pymongo import _csot, ssl_support +from pymongo._asyncio_task import create_task from pymongo.errors import _OperationCancelled from pymongo.socket_checker import _errno_from_exception @@ -259,19 +260,20 @@ async def async_receive_data( sock.settimeout(0.0) loop = asyncio.get_event_loop() - cancellation_task = asyncio.create_task(_poll_cancellation(conn)) + cancellation_task = create_task(_poll_cancellation(conn)) try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] + read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] else: - read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] tasks = [read_task, cancellation_task] done, pending = await asyncio.wait( tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED ) for task in pending: task.cancel() - await asyncio.wait(pending) + if pending: + await asyncio.wait(pending) if len(done) == 0: raise socket.timeout("timed out") if read_task in done: diff --git a/pymongo/synchronous/periodic_executor.py b/pymongo/periodic_executor.py similarity index 67% rename from pymongo/synchronous/periodic_executor.py rename to pymongo/periodic_executor.py index 525268b14..2f89b91de 100644 --- a/pymongo/synchronous/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -23,9 +23,102 @@ import time import weakref from typing import Any, Optional +from pymongo._asyncio_task import create_task from pymongo.lock import _create_lock -_IS_SYNC = True +_IS_SYNC = False + + +class AsyncPeriodicExecutor: + def __init__( + self, + interval: float, + min_interval: float, + target: Any, + name: Optional[str] = None, + ): + """Run a target function periodically on a background task. + + If the target's return value is false, the executor stops. + + :param interval: Seconds between calls to `target`. + :param min_interval: Minimum seconds between calls if `wake` is + called very often. + :param target: A function. + :param name: A name to give the underlying task. + """ + self._event = False + self._interval = interval + self._min_interval = min_interval + self._target = target + self._stopped = False + self._task: Optional[asyncio.Task] = None + self._name = name + self._skip_sleep = False + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" + + def open(self) -> None: + """Start. Multiple calls have no effect.""" + self._stopped = False + + if self._task is None or ( + self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined] + ): + self._task = create_task(self._run(), name=self._name) + + def close(self, dummy: Any = None) -> None: + """Stop. To restart, call open(). + + The dummy parameter allows an executor's close method to be a weakref + callback; see monitor.py. + """ + self._stopped = True + + async def join(self, timeout: Optional[int] = None) -> None: + if self._task is not None: + try: + await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type] + except asyncio.TimeoutError: + # Task timed out + pass + except asyncio.exceptions.CancelledError: + # Task was already finished, or not yet started. + raise + + def wake(self) -> None: + """Execute the target function soon.""" + self._event = True + + def update_interval(self, new_interval: int) -> None: + self._interval = new_interval + + def skip_sleep(self) -> None: + self._skip_sleep = True + + async def _run(self) -> None: + while not self._stopped: + if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined] + raise asyncio.CancelledError + try: + if not await self._target(): + self._stopped = True + break + except BaseException: + self._stopped = True + raise + + if self._skip_sleep: + self._skip_sleep = False + else: + deadline = time.monotonic() + self._interval + while not self._stopped and time.monotonic() < deadline: + await asyncio.sleep(self._min_interval) + if self._event: + break # Early wake. + + self._event = False class PeriodicExecutor: @@ -64,19 +157,6 @@ class PeriodicExecutor: def __repr__(self) -> str: return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" - def _run_async(self) -> None: - # The default asyncio loop implementation on Windows - # has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240) - # We explicitly use a different loop implementation here to prevent that issue - if sys.platform == "win32": - loop = asyncio.SelectorEventLoop() - try: - loop.run_until_complete(self._run()) # type: ignore[func-returns-value] - finally: - loop.close() - else: - asyncio.run(self._run()) # type: ignore[func-returns-value] - def open(self) -> None: """Start. Multiple calls have no effect. @@ -104,10 +184,7 @@ class PeriodicExecutor: pass if not started: - if _IS_SYNC: - thread = threading.Thread(target=self._run, name=self._name) - else: - thread = threading.Thread(target=self._run_async, name=self._name) + thread = threading.Thread(target=self._run, name=self._name) thread.daemon = True self._thread = weakref.proxy(thread) _register_executor(self) diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 625e8429e..9f6e3f7cf 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -474,7 +474,6 @@ class _ClientBulk: if op_type == "delete": res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] full_result[f"{op_type}Results"][original_index] = res - except Exception as exc: # Attempt to close the cursor, then raise top-level error. if cmd_cursor.alive: diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 27a76cf91..9a7637704 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: Connection, more_to_come: bool): self.conn: Optional[Connection] = conn self.more_to_come = more_to_come - self._alock = _create_lock() + self._lock = _create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 506ff8bcb..09d0c0f2f 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -15,6 +15,7 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations +import asyncio import contextlib import enum import socket @@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptionError(exc) from exc @@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. + except asyncio.CancelledError: + raise except Exception as error: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) @@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]): database.create_collection(name=name, **kwargs), encrypted_fields, ) + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 00c6203a9..a694a58c1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -32,6 +32,7 @@ access: """ from __future__ import annotations +import asyncio import contextlib import os import warnings @@ -58,7 +59,7 @@ from typing import ( from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -74,7 +75,11 @@ from pymongo.errors import ( WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ClientBulkWriteResult from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database, periodic_executor +from pymongo.synchronous import client_session, database from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession @@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): address=address, ) - with operation.conn_mgr._alock: + with operation.conn_mgr._lock: with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( @@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): try: if conn_mgr: - with conn_mgr._alock: + with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None @@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_id, conn_mgr in pinned_cursors: try: self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_ids in address_to_cursor_ids.items(): try: self._kill_cursors(cursor_ids, address, topology, session=None) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): try: self._process_kill_cursors() self._topology.update_pool() + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index d02ad0a6f..df4130d4a 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -16,24 +16,24 @@ from __future__ import annotations +import asyncio import atexit import logging import time import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast -from pymongo import common +from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage +from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription from pymongo.srv_resolver import _SrvResolver -from pymongo.synchronous import periodic_executor -from pymongo.synchronous.periodic_executor import _shutdown_executors if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext @@ -238,6 +238,9 @@ class Monitor(MonitorBase): except ReferenceError: # Topology was garbage-collected. self.close() + finally: + if self._executor._stopped: + self._rtt_monitor.close() def _check_server(self) -> ServerDescription: """Call hello or read the next streaming response. @@ -254,6 +257,8 @@ class Monitor(MonitorBase): details = cast(Mapping[str, Any], exc.details) self._topology.receive_cluster_time(details.get("$clusterTime")) raise + except asyncio.CancelledError: + raise except ReferenceError: raise except Exception as error: @@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase): if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception + except asyncio.CancelledError: + raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase): except ReferenceError: # Topology was garbage-collected. self.close() + except asyncio.CancelledError: + raise except Exception: self._pool.reset() @@ -536,4 +545,5 @@ def _shutdown_resources() -> None: shutdown() -atexit.register(_shutdown_resources) +if _IS_SYNC: + atexit.register(_shutdown_resources) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 86baf15b9..1a155c82d 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -23,7 +23,6 @@ import os import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined] _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +211,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -704,6 +702,8 @@ class Connection: # shutdown. try: self.conn.close() + except asyncio.CancelledError: + raise except Exception: # noqa: S110 pass @@ -988,8 +988,8 @@ class Pool: # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _Lock(_lock) + self.lock = _create_lock() + self._max_connecting_cond = _create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1015,7 +1015,7 @@ class Pool: # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(_lock) + self.size_cond = _create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1023,7 +1023,7 @@ class Pool: # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(_lock) + self._max_connecting_cond = _create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1460,7 +1460,8 @@ class Pool: with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1483,7 +1484,8 @@ class Pool: with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index a350c1702..b03269ae4 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -27,7 +27,7 @@ import weakref from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, common, helpers_shared +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.errors import ( ConnectionFailure, InvalidOperation, @@ -39,7 +39,11 @@ from pymongo.errors import ( WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -56,7 +60,6 @@ from pymongo.server_selectors import ( secondary_server_selector, writable_server_selector, ) -from pymongo.synchronous import periodic_executor from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.synchronous.monitor import SrvMonitor from pymongo.synchronous.pool import Pool @@ -170,9 +173,10 @@ class Topology: self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _Lock(_lock) - self._condition = self._settings.condition_class(_lock) + self._lock = _create_lock() + self._condition = _create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -354,7 +358,7 @@ class Topology: # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + _cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -652,7 +656,7 @@ class Topology: """Wake all monitors, wait for at least one to check its server.""" with self._lock: self._request_check_all() - self._condition.wait(wait_time) + _cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. diff --git a/test/__init__.py b/test/__init__.py index fd33fde29..d3a63db2d 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import gc +import logging import multiprocessing import os import signal @@ -25,6 +26,7 @@ import subprocess import sys import threading import time +import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -191,6 +193,8 @@ class ClientContext: client.close() def _init_client(self): + self.mongoses = [] + self.connection_attempts = [] self.client = self._connect(host, port) if self.client is not None: # Return early when connected to dataLake as mongohoused does not @@ -860,6 +864,16 @@ class ClientContext: client_context = ClientContext() +def reset_client_context(): + if _IS_SYNC: + # sync tests don't need to reset a client context + return + elif client_context.client is not None: + client_context.client.close() + client_context.client = None + client_context._init_client() + + class PyMongoTestCase(unittest.TestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase): class UnitTest(PyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - def _setup_class(cls): + def setUp(self) -> None: pass - @classmethod - def _tearDown_class(cls): + def tearDown(self) -> None: pass @@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase): db: Database credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_connection - def _setup_class(cls): - if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + def setUp(self) -> None: + if not _IS_SYNC: + reset_client_context() + if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = client_context.client - cls.db = cls.client.pymongo_test + self.client = client_context.client + self.db = self.client.pymongo_test if client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - def _tearDown_class(cls): - pass + self.credentials = {} def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_no_load_balancer - def _setup_class(cls): - pass - - @classmethod - def _tearDown_class(cls): - pass - - def setUp(self): + def setUp(self) -> None: super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() - def tearDown(self): + def tearDown(self) -> None: self.client_knobs.disable() super().tearDown() @@ -1253,7 +1211,6 @@ def teardown(): c.drop_database("pymongo_test_mike") c.drop_database("pymongo_test_bernie") c.close() - print_running_clients() diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 0579828c4..73e282474 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import gc +import logging import multiprocessing import os import signal @@ -25,6 +26,7 @@ import subprocess import sys import threading import time +import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -191,6 +193,8 @@ class AsyncClientContext: await client.close() async def _init_client(self): + self.mongoses = [] + self.connection_attempts = [] self.client = await self._connect(host, port) if self.client is not None: # Return early when connected to dataLake as mongohoused does not @@ -862,6 +866,16 @@ class AsyncClientContext: async_client_context = AsyncClientContext() +async def reset_client_context(): + if _IS_SYNC: + # sync tests don't need to reset a client context + return + elif async_client_context.client is not None: + await async_client_context.client.close() + async_client_context.client = None + await async_client_context._init_client() + + class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - async def _setup_class(cls): + async def asyncSetUp(self) -> None: pass - @classmethod - async def _tearDown_class(cls): + async def asyncTearDown(self) -> None: pass @@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): db: AsyncDatabase credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + async def asyncSetUp(self) -> None: + if not _IS_SYNC: + await reset_client_context() + if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = async_client_context.client - cls.db = cls.client.pymongo_test + self.client = async_client_context.client + self.db = self.client.pymongo_test if async_client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - async def _tearDown_class(cls): - pass + self.credentials = {} async def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_no_load_balancer - async def _setup_class(cls): - pass - - @classmethod - async def _tearDown_class(cls): - pass - - def setUp(self): - super().setUp() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() - def tearDown(self): + async def asyncTearDown(self) -> None: self.client_knobs.disable() - super().tearDown() + await super().asyncTearDown() async def async_setup(): @@ -1271,7 +1229,6 @@ async def async_teardown(): await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_bernie") await c.close() - print_running_clients() diff --git a/test/asynchronous/conftest.py b/test/asynchronous/conftest.py index e443dff6c..a27a9f213 100644 --- a/test/asynchronous/conftest.py +++ b/test/asynchronous/conftest.py @@ -22,7 +22,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture(scope="package", autouse=True) async def test_setup_and_teardown(): await async_setup() yield diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index c9ff167b4..7191a412c 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest): coll: AsyncCollection coll_w0: AsyncCollection - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() + self.coll = self.db.test await self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase): class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): - @classmethod @async_client_context.require_auth @async_client_context.require_no_api_version - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"]) await self.db.command( "createRole", @@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase): w: Optional[int] secondary: AsyncMongoClient - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.w = async_client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + async def asyncSetUp(self): + await super().asyncSetUp() + self.w = async_client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) + self.secondary = await self.async_single_client(*partition_node(member)) break - @classmethod - async def async_tearDownClass(cls): - if cls.secondary: - await cls.secondary.close() + async def asyncTearDown(self): + if self.secondary: + await self.secondary.close() async def cause_wtimeout(self, requests, ordered): if not async_client_context.test_commands_enabled: diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 8e16fe752..08da00cc1 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -836,18 +836,16 @@ class ProseSpecTestsMixin: class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): dbs: list - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - async def _tearDown_class(cls): - for db in cls.dbs: - await cls.client.drop_database(db) - await super()._tearDown_class() + async def asyncTearDown(self): + for db in self.dbs: + await self.client.drop_database(db) + await super().asyncTearDown() async def change_stream_with_client(self, client, *args, **kwargs): return await client.watch(*args, **kwargs) @@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() async def change_stream_with_client(self, client, *args, **kwargs): return await client[self.db.name].watch(*args, **kwargs) @@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi class TestAsyncCollectionAsyncChangeStream( TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin ): - @classmethod @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): + await super().asyncSetUp() # Use a new collection for each test. await self.watched_collection().drop() await self.watched_collection().insert_one({}) @@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - def asyncSetUp(self): - super().asyncSetUp() + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() async def asyncSetUpCluster(self, scenario_dict): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 590154b85..db232386e 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -73,7 +73,6 @@ from test.utils import ( is_greenthread_patched, lazy_client_trial, one, - wait_until, ) import bson @@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest): client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = await cls.unmanaged_async_rs_or_single_client( + async def asyncSetUp(self) -> None: + self.client = await self.async_rs_or_single_client( connect=False, serverSelectionTimeoutMS=100 ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog @@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest): # When the reaper runs at the same time as the get_socket, two # connections could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.conns), 1) - wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") + await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest): # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two connections from being created. self.assertEqual(1, len(server._pool.conns)) - wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") + await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest): async with server._pool.checkout() as conn_two: pass self.assertIs(conn_one, conn_two) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) @@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest): server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 10, "pool initialized with 10 connections", ) @@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest): # Assert that if a socket is closed, a new one takes its place async with server._pool.checkout() as conn: conn.close_conn(None) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", ) @@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) - def test_getters(self): - wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") + async def test_getters(self): + await async_wait_until( + lambda: async_client_context.nodes == self.client.nodes, "find all nodes" + ) async def test_list_databases(self): cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] @@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. await client.admin.command("ping") self.assertTrue(client._topology._opened) - kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertTrue(kc_task and not kc_task.done()) async def test_close_does_not_open_servers(self): client = await self.async_rs_client(connect=False) @@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest): async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) @@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest): self.assertRaises( ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False ) + await client.close() client = AsyncMongoClient( "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False ) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) + await client.close() # Test invalid timeout in URI ignored and set to default. client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) @@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest): await async_client_context.port, ) await self.async_single_client(uri, event_listeners=[listener]) - wait_until( + await async_wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest): pool = await async_get_pool(client) original_connect = pool.connect - def stall_connect(*args, **kwargs): - time.sleep(2) - return original_connect(*args, **kwargs) + async def stall_connect(*args, **kwargs): + await asyncio.sleep(2) + return await original_connect(*args, **kwargs) pool.connect = stall_connect # Un-patch Pool.connect to break the cyclic reference. self.addCleanup(delattr, pool, "connect") # Wait for the background thread to start creating connections - wait_until(lambda: len(pool.conns) > 1, "start creating connections") + await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections") # Assert that application operations do not block. for _ in range(10): @@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest): await client.close() # Add cursor to kill cursors queue del cursor - wait_until( + await async_wait_until( lambda: client._kill_cursors_queue, "waited for cursor to be added to queue", ) @@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest): await cursor.to_list() self.assertTrue(conn.closed) - wait_until( + await async_wait_until( lambda: len(client._kill_cursors_queue) == 0, "waited for all killCursor requests to complete", ) @@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) # Fail over. @@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") # Total failure. c.kill_host("a:1") @@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest): c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION) c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1) await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) - wait_until(lambda: len(c.nodes) == 2, "connect") + await async_wait_until(lambda: len(c.nodes) == 2, "connect") c.kill_host("a:1") @@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) self.assertEqual(await c.arbiters, {("c", 3)}) # Assert that we create 2 and only 2 pooled connections. - listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) + await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) # Assert that we do not create connections to arbiters. arbiter = c._topology.get_server_by_address(("c", 3)) @@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 1, "connect") + await async_wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) # Assert that we create 1 pooled connection. - listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) + await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) arbiter = c._topology.get_server_by_address(("c", 3)) self.assertEqual(len(arbiter.pool.conns), 1) diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index d95f4c991..d7fd85b16 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest): warn_context: Any collation: Collation - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() - warnings.simplefilter("ignore", DeprecationWarning) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() - @classmethod - async def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - await cls.client.close() - await super()._tearDown_class() - - def tearDown(self): + async def asyncTearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() - super().tearDown() + await super().asyncTearDown() def last_command_started(self): return self.listener.started_events[-1].command diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index db52bad4a..528919f63 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -40,7 +40,6 @@ from test.utils import ( async_get_pool, async_is_mongos, async_wait_until, - wait_until, ) from bson import encode @@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest): db: AsyncDatabase client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = AsyncMongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest): class AsyncTestCollection(AsyncIntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = async_client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - async def async_tearDownClass(cls): - await cls.db.drop_collection("test_large_limit") - async def asyncSetUp(self): - await self.db.test.drop() + await super().asyncSetUp() + self.w = async_client_context.w # type: ignore async def asyncTearDown(self): await self.db.test.drop() + await self.db.drop_collection("test_large_limit") + await super().asyncTearDown() @contextlib.contextmanager def write_concern_collection(self): @@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest): await db.test.insert_one({"y": 1}, bypass_document_validation=True) await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + async def predicate(): + return await db_w0.test.find_one({"x": 1}) + + await async_wait_until(predicate, "find w:0 replaced document") async def test_update_bypass_document_validation(self): db = self.db @@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest): await cur.close() cur = None # Wait until the background thread returns the socket. - wait_until(lambda: pool.active_sockets == 0, "return socket") + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") # The socket should be discarded. self.assertEqual(0, len(pool.conns)) diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 289cf4975..bc9638b44 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -19,7 +19,12 @@ import sys sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + reset_client_context, + unittest, +) from test.asynchronous.helpers import async_repl_set_step_down from test.utils import ( CMAPListener, @@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): listener: CMAPListener coll: AsyncCollection - @classmethod @async_client_context.require_replica_set - async def _setup_class(cls): - await super()._setup_class() - cls.listener = CMAPListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + async def asyncSetUp(self): + self.listener = CMAPListener() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - await async_ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - - async def asyncSetUp(self): + await async_ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). await self.db.drop_collection("step-down") await self.db.create_collection("step-down") diff --git a/test/asynchronous/test_create_entities.py b/test/asynchronous/test_create_entities.py index cb2ec63f4..1f68cf6dd 100644 --- a/test/asynchronous/test_create_entities.py +++ b/test/asynchronous/test_create_entities.py @@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() async def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() if __name__ == "__main__": diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 787da3d95..d21647945 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,9 +34,9 @@ from test.utils import ( AllowListEventListener, EventListener, OvertCommandListener, + async_wait_until, delay, ignore_deprecations, - wait_until, ) from bson import decode_all @@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest): with self.assertRaises(ExecutionTimeout): await cursor.next() - def assertCursorKilled(): - wait_until( + async def assertCursorKilled(): + await async_wait_until( lambda: len(listener.succeeded_events), "find successful killCursors command", ) @@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest): self.assertEqual(1, len(listener.succeeded_events)) self.assertEqual("killCursors", listener.succeeded_events[0].command_name) - assertCursorKilled() + await assertCursorKilled() listener.reset() cursor = await coll.aggregate([], batchSize=1) @@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest): with self.assertRaises(ExecutionTimeout): await cursor.next() - assertCursorKilled() + await assertCursorKilled() def test_delete_not_initialized(self): # Creating a cursor with invalid arguments will not run __init__ @@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest): class TestRawBatchCommandCursor(AsyncIntegrationTest): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - async def test_aggregate_raw(self): c = self.db.test await c.drop() diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 61369c854..b5a596042 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest): class TestDatabaseAggregation(AsyncIntegrationTest): - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 767b3ecf0..048db2d50 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase): class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @async_client_context.require_version_min(4, 2, -1) - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest): class TestClientMaxWireVersion(AsyncIntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @async_client_context.require_version_max(4, 0, 99) async def test_raise_max_wire_version_error(self): @@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - await cls.client.db.coll.drop() - cls.vault = await create_key_vault(cls.client.keyvault.datakeys) + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + await self.client.db.coll.drop() + self.vault = await create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( + self.client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - async def _tearDown_class(cls): - await cls.vault.drop() - await cls.client.close() - await cls.client_encrypted.close() - await cls.client_encryption.close() - - def setUp(self): self.listener.reset() + async def asyncTearDown(self) -> None: + await self.vault.drop() + async def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest): class TestCorpus(AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @staticmethod def kms_providers(): @@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): client_encrypted: AsyncMongoClient listener: OvertCommandListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() db = async_client_context.client.db - cls.coll = db.coll - await cls.coll.drop() + self.coll = db.coll + await self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data("limits", "limits-schema.json") await db.create_collection( @@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): await coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = await self.async_rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - async def _tearDown_class(cls): - await cls.coll_encrypted.drop() - await cls.client_encrypted.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + await self.coll_encrypted.drop() async def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} self.listener.reset() await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) async def test_04_bulk_batch_split(self): limits_doc = json_data("limits", "limits-doc.json") @@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): doc2.update(limits_doc) self.listener.reset() await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) async def test_05_insert_succeeds_just_under_16MiB(self): doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} @@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): class TestCustomEndpoint(AsyncEncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest): self._kmip_host_error = None self._invalid_host_error = None - async def asyncTearDown(self): - await self.client_encryption.close() - await self.client_encryption_invalid.close() - async def run_test_expected_success(self, provider_name, master_key): data_key_id = await self.client_encryption.create_data_key( provider_name, master_key=master_key @@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): KEYVAULT_COLL = "datakeys" client: AsyncMongoClient - async def asyncSetUp(self): + async def _setup(self): keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) await create_key_vault(keyvault, self.DEK) async def _test_explicit(self, expectation): + await self._setup() client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), async_client_context.client, OPTS, ) - self.addAsyncCleanup(client_encryption.close) ciphertext = await client_encryption.encrypt( "string0", @@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): self.assertEqual(await client_encryption.decrypt(ciphertext), "string0") async def _test_automatic(self, expectation_extjson, payload): + await self._setup() encrypted_db = "db" encrypted_coll = "coll" keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) @@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): client = await self.async_rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) - self.addAsyncCleanup(client.aclose) coll = client.get_database(encrypted_db).get_collection( encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") @@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests class TestDeadlockProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() self.client_test = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) @@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest): self.ciphertext = await client_encryption.encrypt( "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" ) - await client_encryption.close() self.client_listener = OvertCommandListener() self.topology_listener = TopologyEventListener() @@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events class TestDecryptProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() self.client = async_client_context.client await self.client.db.drop_collection("decryption_events") await create_key_vault(self.client.keyvault.datakeys) @@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest): # https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() self.client = async_client_context.client await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} @@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest): assert isinstance(res["encrypted_indexed"], Binary) assert isinstance(res["encrypted_unindexed"], Binary) - await client_encryption.close() - # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption class TestRangeQueryProse(AsyncEncryptionIntegrationTest): @@ -3089,17 +3073,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest): mongocryptd_client: AsyncMongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - async def _setup_class(cls): - await super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - async def _tearDown_class(cls): - await super()._tearDown_class() - async def asyncSetUp(self) -> None: + await super().asyncSetUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 54fcd3abf..affdacde9 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest): class AsyncTestGridFile(AsyncIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) async def test_basic(self): diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index e0e7f2fc8..e5a0adfee 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -16,498 +16,447 @@ from __future__ import annotations import asyncio import sys -import threading import unittest +from pymongo.lock import _async_create_condition, _async_create_lock + sys.path[0:0] = [""] -from pymongo.lock import _ACondition +if sys.version_info < (3, 13): + # Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py + # Includes tests for: + # - https://github.com/python/cpython/issues/111693 + # - https://github.com/python/cpython/issues/112202 + class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _async_create_condition(_async_create_lock()) + result = [] - -# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py -# Includes tests for: -# - https://github.com/python/cpython/issues/111693 -# - https://github.com/python/cpython/issues/112202 -class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): - async def test_wait(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] - - async def c1(result): - await cond.acquire() - if await cond.wait(): - result.append(1) - return True - - async def c2(result): - await cond.acquire() - if await cond.wait(): - result.append(2) - return True - - async def c3(result): - await cond.acquire() - if await cond.wait(): - result.append(3) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertFalse(cond.locked()) - - self.assertTrue(await cond.acquire()) - cond.notify() - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.notify(2) - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2, 3], result) - self.assertTrue(cond.locked()) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_wait_cancel(self): - cond = _ACondition(threading.Condition(threading.Lock())) - await cond.acquire() - - wait = asyncio.create_task(cond.wait()) - asyncio.get_running_loop().call_soon(wait.cancel) - with self.assertRaises(asyncio.CancelledError): - await wait - self.assertFalse(cond._waiters) - self.assertTrue(cond.locked()) - - async def test_wait_cancel_contested(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - await cond.acquire() - self.assertTrue(cond.locked()) - - wait_task = asyncio.create_task(cond.wait()) - await asyncio.sleep(0) - self.assertFalse(cond.locked()) - - # Notify, but contest the lock before cancelling - await cond.acquire() - self.assertTrue(cond.locked()) - cond.notify() - asyncio.get_running_loop().call_soon(wait_task.cancel) - asyncio.get_running_loop().call_soon(cond.release) - - try: - await wait_task - except asyncio.CancelledError: - # Should not happen, since no cancellation points - pass - - self.assertTrue(cond.locked()) - - async def test_wait_cancel_after_notify(self): - # See bpo-32841 - waited = False - - cond = _ACondition(threading.Condition(threading.Lock())) - - async def wait_on_cond(): - nonlocal waited - async with cond: - waited = True # Make sure this area was reached - await cond.wait() - - waiter = asyncio.create_task(wait_on_cond()) - await asyncio.sleep(0) # Start waiting - - await cond.acquire() - cond.notify() - await asyncio.sleep(0) # Get to acquire() - waiter.cancel() - await asyncio.sleep(0) # Activate cancellation - cond.release() - await asyncio.sleep(0) # Cancellation should occur - - self.assertTrue(waiter.cancelled()) - self.assertTrue(waited) - - async def test_wait_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) - with self.assertRaises(RuntimeError): - await cond.wait() - - async def test_wait_for(self): - cond = _ACondition(threading.Condition(threading.Lock())) - presult = False - - def predicate(): - return presult - - result = [] - - async def c1(result): - await cond.acquire() - if await cond.wait_for(predicate): - result.append(1) - cond.release() - return True - - t = asyncio.create_task(c1(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([], result) - - presult = True - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - - self.assertTrue(t.done()) - self.assertTrue(t.result()) - - async def test_wait_for_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - # predicate can return true immediately - res = await cond.wait_for(lambda: [1, 2, 3]) - self.assertEqual([1, 2, 3], res) - - with self.assertRaises(RuntimeError): - await cond.wait_for(lambda: False) - - async def test_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] - - async def c1(result): - async with cond: + async def c1(result): + await cond.acquire() if await cond.wait(): result.append(1) return True - async def c2(result): - async with cond: + async def c2(result): + await cond.acquire() if await cond.wait(): result.append(2) return True - async def c3(result): - async with cond: + async def c3(result): + await cond.acquire() if await cond.wait(): result.append(3) return True - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - await asyncio.sleep(0) - self.assertEqual([], result) - - async with cond: - cond.notify(1) - await asyncio.sleep(1) - self.assertEqual([1], result) - - async with cond: - cond.notify(1) - cond.notify(2048) - await asyncio.sleep(1) - self.assertEqual([1, 2, 3], result) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - result = [] - - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True - - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - - async with cond: - cond.notify_all() - await asyncio.sleep(1) - self.assertEqual([1, 2], result) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - - async def test_context_manager(self): - cond = _ACondition(threading.Condition(threading.Lock())) - self.assertFalse(cond.locked()) - async with cond: - self.assertTrue(cond.locked()) - self.assertFalse(cond.locked()) - - async def test_timeout_in_block(self): - condition = _ACondition(threading.Condition(threading.Lock())) - async with condition: - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for(condition.wait(), timeout=0.5) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_wakeup(self): - # Test that a cancelled error, received when awaiting wakeup, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_re_aquire(self): - # Test that a cancelled error, received when re-aquiring lock, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition - await cond.acquire() - wake = True - cond.notify() - await asyncio.sleep(0) - # Task is now trying to re-acquire the lock, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - cond.release() - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is awaiting initial - # wakeup on the wakeup queue. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # Cancel it while it is awaiting to be run. - # This cancellation could come from the outside - c[0].cancel() - - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) - - # clean up - state = -1 - condition.notify_all() - await c[1] - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup_relock(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is acquiring the lock - # again. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # now we sleep for a bit. This allows the target task to wake up and - # settle on re-aquiring the lock await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) - # Cancel it while awaiting the lock - # This cancel could come the outside. - c[0].cancel() + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = _async_create_condition(_async_create_lock()) + await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _async_create_condition(_async_create_lock()) + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points pass - self.assertEqual(state, 0) - # clean up - state = -1 - condition.notify_all() - await c[1] + self.assertTrue(cond.locked()) + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False -class TestCondition(unittest.IsolatedAsyncioTestCase): - async def test_multiple_loops_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) + cond = _async_create_condition(_async_create_lock()) - def tmain(cond): - async def atmain(cond): - await asyncio.sleep(1) + async def wait_on_cond(): + nonlocal waited async with cond: - cond.notify(1) + waited = True # Make sure this area was reached + await cond.wait() - asyncio.run(atmain(cond)) + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting - t = threading.Thread(target=tmain, args=(cond,)) - t.start() + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur - async with cond: - self.assertTrue(await cond.wait(30)) - t.join() + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) - async def test_multiple_loops_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) - results = [] + async def test_wait_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + with self.assertRaises(RuntimeError): + await cond.wait() - def tmain(cond, results): - async def atmain(cond, results): - await asyncio.sleep(1) + async def test_wait_for(self): + cond = _async_create_condition(_async_create_lock()) + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = _async_create_condition(_async_create_lock()) + result = [] + + async def c1(result): async with cond: - res = await cond.wait(30) - results.append(res) + if await cond.wait(): + result.append(1) + return True - asyncio.run(atmain(cond, results)) + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True - nthreads = 5 - threads = [] - for _ in range(nthreads): - threads.append(threading.Thread(target=tmain, args=(cond, results))) - for t in threads: - t.start() + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True - await asyncio.sleep(2) - async with cond: - cond.notify_all() + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - for t in threads: - t.join() + await asyncio.sleep(0) + self.assertEqual([], result) - self.assertEqual(results, [True] * nthreads) + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) -if __name__ == "__main__": - unittest.main() + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _async_create_condition(_async_create_lock()) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _async_create_condition(_async_create_lock()) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_timeout_in_block(self): + condition = _async_create_condition(_async_create_lock()) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index b0c86ab54..eaad60bea 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): listener: EventListener @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - async def asyncTearDown(self): + @async_client_context.require_connection + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.listener.reset() - await super().asyncTearDown() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False + ) async def test_started_simple(self): await self.client.pymongo_test.command("ping") @@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest): saved_listeners: Any @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = await cls.unmanaged_async_single_client() - # Get one (authenticated) socket in the pool. - await cls.client.pymongo_test.command("ping") - - @classmethod - async def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - await cls.client.close() - await super()._tearDown_class() + @async_client_context.require_connection async def asyncSetUp(self): await super().asyncSetUp() self.listener.reset() + self.client = await self.async_single_client() + # Get one (authenticated) socket in the pool. + await self.client.pymongo_test.command("ping") + + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners async def test_simple(self): await self.client.pymongo_test.command("ping") diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index ca2f0a542..738ce0419 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - async def _tearDown_class(cls): - cls.deprecation_filter.stop() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = await self.async_rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.knobs.disable() @async_client_context.require_no_standalone async def test_actionable_error_message(self): @@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @async_client_context.require_no_mmap - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client( + retryWrites=True, event_listeners=[self.listener] ) - cls.db = cls.client.pymongo_test + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() - - async def asyncSetUp(self): if async_client_context.is_rs and async_client_context.test_commands_enabled: await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest): await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() async def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @async_client_context.require_replica_set @async_client_context.require_no_mmap @async_client_context.require_failCommand_fail_point - async def _setup_class(cls): - await super()._setup_class() - cls.fail_insert = { + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index b43262179..42bc253b5 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -38,7 +38,6 @@ from test.utils import ( ExceptionCatchingThread, OvertCommandListener, async_wait_until, - wait_until, ) from bson import DBRef @@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest): client2: AsyncMongoClient sensitive_commands: Set[str] - @classmethod @async_client_context.require_sessions - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await cls.unmanaged_async_rs_or_single_client() + self.client2 = await self.async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - async def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - await cls.client2.close() - await super()._tearDown_class() - - async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addAsyncCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} async def asyncTearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) await self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + await super().asyncTearDown() + async def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @async_client_context.require_sessions async def asyncSetUp(self): await super().asyncSetUp() + self.listener = SessionTestListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) @async_client_context.require_no_standalone async def test_core(self): diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index b5d068641..d11d0a977 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.utils import ( OvertCommandListener, - wait_until, + async_wait_until, ) from typing import List @@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase): client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) - wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) @@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase): client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) - wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) @@ -403,21 +403,12 @@ class PatchSessionTimeout: class TestTransactionsConvenientAPI(AsyncTransactionsBase): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) - - @classmethod - async def _tearDown_class(cls): - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) async def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index db5ed81e2..b18b09383 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -50,6 +50,7 @@ from test.unified_format_shared import ( ) from test.utils import ( async_get_pool, + async_wait_until, camel_to_snake, camel_to_snake_args, parse_spec_options, @@ -304,7 +305,6 @@ class EntityMapUtil: kwargs["h"] = uri client = await self.test.async_rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addAsyncCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): await db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - async def _setup_class(cls): - # super call creates internal client cls.client - await super()._setup_class() - # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not await cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") - - # add any special-casing for skipping tests here - if async_client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH - ): - raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") - - # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] - if ( - async_client_context.supports_transactions() - and not async_client_context.load_balancer - and not async_client_context.serverless - ): - for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) - + def setUpClass(cls) -> None: # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs( heartbeat_frequency=0.1, @@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): cls.knobs.enable() @classmethod - async def _tearDown_class(cls): + def tearDownClass(cls) -> None: cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() async def asyncSetUp(self): + # super call creates internal client cls.client await super().asyncSetUp() + # process file-level runOnRequirements + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not await self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") + + # add any special-casing for skipping tests here + if async_client_context.storage_engine == "mmapv1": + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH + ): + raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") + + # Handle mongos_clients for transactions tests. + self.mongos_clients = [] + if ( + async_client_context.supports_transactions() + and not async_client_context.load_balancer + and not async_client_context.serverless + ): + for address in async_client_context.mongoses: + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) + # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): ) client = await self.async_single_client("{}:{}".format(*session._pinned_address)) - self.addAsyncCleanup(client.close) await self.__set_fail_point(client=client, command_args=spec["failPoint"]) async def _testOperation_createEntities(self, spec): @@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): client, event, count = spec["client"], spec["event"], spec["count"] self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}") - def _testOperation_waitForEvent(self, spec): + async def _testOperation_waitForEvent(self, spec): """Run the waitForEvent test operation. Wait for a number of events to be published, or fail. """ client, event, count = spec["client"], spec["event"], spec["count"] - wait_until( + await async_wait_until( lambda: self._event_count(client, event) >= count, f"find {count} {event} event(s)", ) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f27f52ec2..b79e5258b 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + async def asyncTearDown(self) -> None: + self.knobs.disable() + async def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) @@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest): self.listener = listener self.pool_listener = pool_listener self.server_listener = server_listener - # Close the client explicitly to avoid having too many threads open. - self.addAsyncCleanup(client.close) # Create session0 and session1. sessions = {} diff --git a/test/conftest.py b/test/conftest.py index a3d954c7c..91fad28d0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -20,7 +20,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="package", autouse=True) def test_setup_and_teardown(): setup() yield diff --git a/test/test_bulk.py b/test/test_bulk.py index ea2b80380..6d29ff510 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest): coll: Collection coll_w0: Collection - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - def setUp(self): super().setUp() + self.coll = self.db.test self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -785,12 +781,8 @@ class TestBulk(BulkTestBase): class BulkAuthorizationTestBase(BulkTestBase): - @classmethod @client_context.require_auth @client_context.require_no_api_version - def _setup_class(cls): - super()._setup_class() - def setUp(self): super().setUp() client_context.create_user(self.db.name, "readonly", "pw", ["read"]) @@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase): w: Optional[int] secondary: MongoClient - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.w = client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + def setUp(self): + super().setUp() + self.w = client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = cls.unmanaged_single_client(*partition_node(member)) + self.secondary = self.single_client(*partition_node(member)) break - @classmethod - def async_tearDownClass(cls): - if cls.secondary: - cls.secondary.close() + def tearDown(self): + if self.secondary: + self.secondary.close() def cause_wtimeout(self, requests, ordered): if not client_context.test_commands_enabled: diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 3a107122b..4ed21f55c 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -820,18 +820,16 @@ class ProseSpecTestsMixin: class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): dbs: list - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + def setUp(self) -> None: + super().setUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - def _tearDown_class(cls): - for db in cls.dbs: - cls.client.drop_database(db) - super()._tearDown_class() + def tearDown(self): + for db in self.dbs: + self.client.drop_database(db) + super().tearDown() def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) @@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) @@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): - @classmethod @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() # Use a new collection for each test. self.watched_collection().drop() self.watched_collection().insert_one({}) @@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - def setUp(self): super().setUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = self.rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() def setUpCluster(self, scenario_dict): diff --git a/test/test_client.py b/test/test_client.py index 5bbb5bd75..5ec425f31 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest): client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. client.admin.command("ping") self.assertTrue(client._topology._opened) - kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertTrue(kc_task and not kc_task.done()) def test_close_does_not_open_servers(self): client = self.rs_client(connect=False) @@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest): def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + client.close() client = MongoClient(serverSelectionTimeoutMS=0, connect=False) @@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest): self.assertRaises( ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False ) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) + client.close() # Test invalid timeout in URI ignored and set to default. client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) diff --git a/test/test_collation.py b/test/test_collation.py index b878df2fb..06436f063 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -97,26 +97,19 @@ class TestCollation(IntegrationTest): warn_context: Any collation: Collation - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() - warnings.simplefilter("ignore", DeprecationWarning) + def setUp(self) -> None: + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() - @classmethod - def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + def tearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() super().tearDown() diff --git a/test/test_collection.py b/test/test_collection.py index 84a900d45..af524bba4 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest): db: Database client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = MongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + super().setUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest): class TestCollection(IntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - def async_tearDownClass(cls): - cls.db.drop_collection("test_large_limit") - def setUp(self): - self.db.test.drop() + super().setUp() + self.w = client_context.w # type: ignore def tearDown(self): self.db.test.drop() + self.db.drop_collection("test_large_limit") + super().tearDown() @contextlib.contextmanager def write_concern_collection(self): @@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest): db.test.insert_one({"y": 1}, bypass_document_validation=True) db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + def predicate(): + return db_w0.test.find_one({"x": 1}) + + wait_until(predicate, "find w:0 replaced document") def test_update_bypass_document_validation(self): db = self.db diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 54cc4e048..84ef6decd 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -19,7 +19,12 @@ import sys sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + reset_client_context, + unittest, +) from test.helpers import repl_set_step_down from test.utils import ( CMAPListener, @@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener coll: Collection - @classmethod @client_context.require_replica_set - def _setup_class(cls): - super()._setup_class() - cls.listener = CMAPListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + def setUp(self): + self.listener = CMAPListener() + self.client = self.rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - - def setUp(self): + ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). self.db.drop_collection("step-down") self.db.create_collection("step-down") diff --git a/test/test_create_entities.py b/test/test_create_entities.py index ad75fe570..9d77a08ee 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() if __name__ == "__main__": diff --git a/test/test_cursor.py b/test/test_cursor.py index 9eac0f1c4..bcc7ed75f 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest): class TestRawBatchCommandCursor(IntegrationTest): - @classmethod - def _setup_class(cls): - super()._setup_class() - def test_aggregate_raw(self): c = self.db.test c.drop() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index abaa820cb..6771ea25f 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase): class TestCollectionWCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.test.drop() def tearDown(self): @@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest): class TestGridFileCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") @@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin: class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() diff --git a/test/test_database.py b/test/test_database.py index 4973ed013..5e854c941 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest): class TestDatabaseAggregation(IntegrationTest): def setUp(self): + super().setUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/test_encryption.py b/test/test_encryption.py index 0806f91a0..cb8bcb74d 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -211,11 +211,10 @@ class TestClientOptions(PyMongoTestCase): class EncryptionIntegrationTest(IntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest): class TestClientMaxWireVersion(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): @@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.client.db.coll.drop() - cls.vault = create_key_vault(cls.client.keyvault.datakeys) + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.client.db.coll.drop() + self.vault = create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -844,25 +841,22 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = cls.unmanaged_rs_or_single_client( + self.client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - def _tearDown_class(cls): - cls.vault.drop() - cls.client.close() - cls.client_encrypted.close() - cls.client_encryption.close() - - def setUp(self): self.listener.reset() + def tearDown(self) -> None: + self.vault.drop() + def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1007,10 +1001,9 @@ class TestViews(EncryptionIntegrationTest): class TestCorpus(EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @staticmethod def kms_providers(): @@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): client_encrypted: MongoClient listener: OvertCommandListener - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() db = client_context.client.db - cls.coll = db.coll - cls.coll.drop() + self.coll = db.coll + self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data("limits", "limits-schema.json") db.create_collection( @@ -1207,17 +1199,14 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = cls.unmanaged_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = self.rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - def _tearDown_class(cls): - cls.coll_encrypted.drop() - cls.client_encrypted.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.coll_encrypted.drop() def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1241,7 +1230,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) def test_04_bulk_batch_split(self): limits_doc = json_data("limits", "limits-doc.json") @@ -1251,7 +1242,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): doc2.update(limits_doc) self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) def test_05_insert_succeeds_just_under_16MiB(self): doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} @@ -1281,15 +1274,12 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): class TestCustomEndpoint(EncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1318,10 +1308,6 @@ class TestCustomEndpoint(EncryptionIntegrationTest): self._kmip_host_error = None self._invalid_host_error = None - def tearDown(self): - self.client_encryption.close() - self.client_encryption_invalid.close() - def run_test_expected_success(self, provider_name, master_key): data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key) encrypted = self.client_encryption.encrypt( @@ -1494,18 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): KEYVAULT_COLL = "datakeys" client: MongoClient - def setUp(self): + def _setup(self): keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) create_key_vault(keyvault, self.DEK) def _test_explicit(self, expectation): + self._setup() client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, OPTS, ) - self.addCleanup(client_encryption.close) ciphertext = client_encryption.encrypt( "string0", @@ -1517,6 +1503,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): self.assertEqual(client_encryption.decrypt(ciphertext), "string0") def _test_automatic(self, expectation_extjson, payload): + self._setup() encrypted_db = "db" encrypted_coll = "coll" keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) @@ -1531,7 +1518,6 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): client = self.rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) - self.addCleanup(client.close) coll = client.get_database(encrypted_db).get_collection( encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") @@ -1553,13 +1539,12 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -1579,13 +1564,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -1607,6 +1591,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests class TestDeadlockProse(EncryptionIntegrationTest): def setUp(self): + super().setUp() self.client_test = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) @@ -1637,7 +1622,6 @@ class TestDeadlockProse(EncryptionIntegrationTest): self.ciphertext = client_encryption.encrypt( "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" ) - client_encryption.close() self.client_listener = OvertCommandListener() self.topology_listener = TopologyEventListener() @@ -1832,6 +1816,7 @@ class TestDeadlockProse(EncryptionIntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events class TestDecryptProse(EncryptionIntegrationTest): def setUp(self): + super().setUp() self.client = client_context.client self.client.db.drop_collection("decryption_events") create_key_vault(self.client.keyvault.datakeys) @@ -2267,6 +2252,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest): # https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest): def setUp(self): + super().setUp() self.client = client_context.client create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} @@ -2608,8 +2594,6 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest): assert isinstance(res["encrypted_indexed"], Binary) assert isinstance(res["encrypted_unindexed"], Binary) - client_encryption.close() - # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption class TestRangeQueryProse(EncryptionIntegrationTest): @@ -3071,17 +3055,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest): mongocryptd_client: MongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - def _setup_class(cls): - super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - def _tearDown_class(cls): - super()._tearDown_class() - def setUp(self) -> None: + super().setUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/test_examples.py b/test/test_examples.py index ebf1d784a..7f98226e7 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -33,19 +33,14 @@ from pymongo.write_concern import WriteConcern class TestSampleShellCommands(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - # Run once before any tests run. - cls.db.inventory.drop() - - @classmethod - def tearDownClass(cls): - cls.client.drop_database("pymongo_test") + def setUp(self): + super().setUp() + self.db.inventory.drop() def tearDown(self): # Run after every test. self.db.inventory.drop() + self.client.drop_database("pymongo_test") def test_first_three_examples(self): db = self.db diff --git a/test/test_grid_file.py b/test/test_grid_file.py index c35efccef..6534bc11b 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -97,6 +97,7 @@ class TestGridFileNoConnect(UnitTest): class TestGridFile(IntegrationTest): def setUp(self): + super().setUp() self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) def test_basic(self): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 549dc0b20..a36109f39 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -75,9 +75,9 @@ class JustRead(threading.Thread): class TestGridfsNoConnect(unittest.TestCase): db: Database - @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def setUp(self): + super().setUp() + self.db = MongoClient(connect=False).pymongo_test def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") @@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFS alt: gridfs.GridFS - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFS(cls.db) - cls.alt = gridfs.GridFS(cls.db, "alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFS(self.db) + self.alt = gridfs.GridFS(self.db, "alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -509,10 +506,9 @@ class TestGridfs(IntegrationTest): class TestGridfsReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 28adb7051..04c742735 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFSBucket alt: gridfs.GridFSBucket - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFSBucket(cls.db) - cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFSBucket(self.db) + self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -479,10 +476,9 @@ class TestGridfs(IntegrationTest): class TestGridfsBucketReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_monitor.py b/test/test_monitor.py index f8e9443fa..a704f3d8c 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -29,7 +29,7 @@ from test.utils import ( wait_until, ) -from pymongo.synchronous.periodic_executor import _EXECUTORS +from pymongo.periodic_executor import _EXECUTORS def unregistered(ref): diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 75fe5c987..670558c0a 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -52,22 +52,14 @@ class TestCommandMonitoring(IntegrationTest): listener: EventListener @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + @client_context.require_connection + def setUp(self) -> None: + super().setUp() self.listener.reset() - super().tearDown() + self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False) def test_started_simple(self): self.client.pymongo_test.command("ping") @@ -1140,26 +1132,23 @@ class TestGlobalListener(IntegrationTest): saved_listeners: Any @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = cls.unmanaged_single_client() - # Get one (authenticated) socket in the pool. - cls.client.pymongo_test.command("ping") - - @classmethod - def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - cls.client.close() - super()._tearDown_class() + @client_context.require_connection def setUp(self): super().setUp() self.listener.reset() + self.client = self.single_client() + # Get one (authenticated) socket in the pool. + self.client.pymongo_test.command("ping") + + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners def test_simple(self): self.client.pymongo_test.command("ping") diff --git a/test/test_read_concern.py b/test/test_read_concern.py index ea9ce49a3..f7c090142 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -31,24 +31,16 @@ from pymongo.read_concern import ReadConcern class TestReadConcern(IntegrationTest): listener: OvertCommandListener - @classmethod @client_context.require_connection - def setUpClass(cls): - super().setUpClass() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") - @classmethod - def tearDownClass(cls): - cls.client.close() - client_context.client.pymongo_test.drop_collection("coll") - super().tearDownClass() - def tearDown(self): - self.listener.reset() - super().tearDown() + client_context.client.pymongo_test.drop_collection("coll") def test_read_concern(self): rc = ReadConcern() diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 74f3c23e5..07bd1db0b 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(IntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + def setUp(self) -> None: + super().setUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - def _tearDown_class(cls): - cls.deprecation_filter.stop() - super()._tearDown_class() + def tearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = self.rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.knobs.disable() @client_context.require_no_standalone def test_actionable_error_message(self): @@ -180,26 +173,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @client_context.require_no_mmap - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] - ) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener]) + self.db = self.client.pymongo_test - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() - - def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -210,6 +193,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest): self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -438,13 +422,12 @@ class TestWriteConcernError(IntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @client_context.require_replica_set @client_context.require_no_mmap @client_context.require_failCommand_fail_point - def _setup_class(cls): - super()._setup_class() - cls.fail_insert = { + def setUp(self) -> None: + super().setUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 81b208d51..6b808b159 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest): @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): - super().setUpClass() + super().setUp(cls) # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs( events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 diff --git a/test/test_session.py b/test/test_session.py index d0bbb075a..634efa11c 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -82,36 +82,27 @@ class TestSession(IntegrationTest): client2: MongoClient sensitive_commands: Set[str] - @classmethod @client_context.require_sessions - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = cls.unmanaged_rs_or_single_client() + self.client2 = self.rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - cls.client2.close() - super()._tearDown_class() - - def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} def tearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -121,6 +112,8 @@ class TestSession(IntegrationTest): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + super().tearDown() + def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -832,18 +825,11 @@ class TestCausalConsistency(UnitTest): listener: SessionTestListener client: MongoClient - @classmethod - def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - @client_context.require_sessions def setUp(self): super().setUp() + self.listener = SessionTestListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) @client_context.require_no_standalone def test_core(self): diff --git a/test/test_threads.py b/test/test_threads.py index b3dadbb1a..3e469e28f 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -105,6 +105,7 @@ class Update(threading.Thread): class TestThreads(IntegrationTest): def setUp(self): + super().setUp() self.db = self.client.pymongo_test def test_threading(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index 3cecbe9d3..949b88e60 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -395,19 +395,12 @@ class PatchSessionTimeout: class TestTransactionsConvenientAPI(TransactionsBase): - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - - @classmethod - def _tearDown_class(cls): - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/test_typing.py b/test/test_typing.py index 441707616..bfe4d032c 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -114,10 +114,9 @@ class TestMypyFails(unittest.TestCase): class TestPymongo(IntegrationTest): coll: Collection - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.coll = cls.client.test.test + def setUp(self): + super().setUp() + self.coll = self.client.test.test def test_insert_find(self) -> None: doc = {"my": "doc"} diff --git a/test/unified_format.py b/test/unified_format.py index 3489a8ac8..5cb268a29 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -304,7 +304,6 @@ class EntityMapUtil: kwargs["h"] = uri client = self.test.rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -479,31 +478,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - def _setup_class(cls): - # super call creates internal client cls.client - super()._setup_class() - # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") - - # add any special-casing for skipping tests here - if client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH - ): - raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") - - # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] - if ( - client_context.supports_transactions() - and not client_context.load_balancer - and not client_context.serverless - ): - for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - + def setUpClass(cls) -> None: # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs( heartbeat_frequency=0.1, @@ -514,17 +489,36 @@ class UnifiedSpecTestMixinV1(IntegrationTest): cls.knobs.enable() @classmethod - def _tearDown_class(cls): + def tearDownClass(cls) -> None: cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() def setUp(self): + # super call creates internal client cls.client super().setUp() + # process file-level runOnRequirements + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") + + # add any special-casing for skipping tests here + if client_context.storage_engine == "mmapv1": + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH + ): + raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") + + # Handle mongos_clients for transactions tests. + self.mongos_clients = [] + if ( + client_context.supports_transactions() + and not client_context.load_balancer + and not client_context.serverless + ): + for address in client_context.mongoses: + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) + # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -1026,7 +1020,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest): ) client = self.single_client("{}:{}".format(*session._pinned_address)) - self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) def _testOperation_createEntities(self, spec): diff --git a/test/utils.py b/test/utils.py index 9b326e5d7..69154bc63 100644 --- a/test/utils.py +++ b/test/utils.py @@ -99,6 +99,12 @@ class BaseListener: """Wait for a number of events to be published, or fail.""" wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") + async def async_wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + await async_wait_until( + lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" + ) + class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): def connection_created(self, event): @@ -644,7 +650,10 @@ async def async_wait_until(predicate, success_description, timeout=10): start = time.time() interval = min(float(timeout) / 100, 0.1) while True: - retval = await predicate() + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() if retval: return retval diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 8b2679d77..4508502cd 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -249,30 +249,22 @@ class SpecRunner(IntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + def tearDown(self) -> None: + self.knobs.disable() + def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) @@ -697,8 +689,6 @@ class SpecRunner(IntegrationTest): self.listener = listener self.pool_listener = pool_listener self.server_listener = server_listener - # Close the client explicitly to avoid having too many threads open. - self.addCleanup(client.close) # Create session0 and session1. sessions = {} diff --git a/tools/synchro.py b/tools/synchro.py index 0a7109c6d..47617365f 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -110,6 +110,13 @@ replacements = { "async_set_fail_point": "set_fail_point", "async_ensure_all_connected": "ensure_all_connected", "async_repl_set_step_down": "repl_set_step_down", + "AsyncPeriodicExecutor": "PeriodicExecutor", + "async_wait_for_event": "wait_for_event", + "pymongo_server_monitor_task": "pymongo_server_monitor_thread", + "pymongo_server_rtt_task": "pymongo_server_rtt_thread", + "_async_create_lock": "_create_lock", + "_async_create_condition": "_create_condition", + "_async_cond_wait": "_cond_wait", } docstring_replacements: dict[tuple[str, str], str] = { @@ -130,8 +137,6 @@ docstring_removals: set[str] = { ".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release." } -type_replacements = {"_Condition": "threading.Condition"} - import_replacements = {"test.synchronous": "test"} _pymongo_base = "./pymongo/asynchronous/" @@ -234,8 +239,6 @@ def process_files(files: list[str]) -> None: lines = translate_async_sleeps(lines) if file in docstring_translate_files: lines = translate_docstrings(lines) - translate_locks(lines) - translate_types(lines) if file in sync_test_files: translate_imports(lines) f.seek(0) @@ -269,34 +272,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]: return lines -def translate_locks(lines: list[str]) -> list[str]: - lock_lines = [line for line in lines if "_Lock(" in line] - cond_lines = [line for line in lines if "_Condition(" in line] - for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - for line in cond_lines: - res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - - return lines - - -def translate_types(lines: list[str]) -> list[str]: - for k, v in type_replacements.items(): - matches = [line for line in lines if k in line and "import" not in line] - for line in matches: - index = lines.index(line) - lines[index] = line.replace(k, v) - return lines - - def translate_imports(lines: list[str]) -> list[str]: for k, v in import_replacements.items(): matches = [line for line in lines if k in line and "import" in line] From cbeebd01901467a204687a70ca498dad70539502 Mon Sep 17 00:00:00 2001 From: theRealProHacker <77074862+theRealProHacker@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:54:56 +0100 Subject: [PATCH 6/9] Small doc fix (#2021) Co-authored-by: Steven Silvester --- pymongo/asynchronous/cursor.py | 2 +- pymongo/synchronous/cursor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 7d7ae4a5d..8193e5328 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1299,7 +1299,7 @@ class AsyncCursor(Generic[_DocumentType]): >>> await cursor.to_list() - Or, so read at most n items from the cursor:: + Or, to read at most n items from the cursor:: >>> await cursor.to_list(n) diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 9a7637704..b35098a32 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1297,7 +1297,7 @@ class Cursor(Generic[_DocumentType]): >>> cursor.to_list() - Or, so read at most n items from the cursor:: + Or, to read at most n items from the cursor:: >>> cursor.to_list(n) From bc66598623c46e072518f6f11347096476b100c1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 2 Dec 2024 12:17:52 -0500 Subject: [PATCH 7/9] PYTHON-4965 - Consolidate startup and teardown tasks (#2017) --- .evergreen/combine-coverage.sh | 0 .evergreen/config.yml | 121 +++++++----------- .evergreen/hatch.sh | 0 .evergreen/install-dependencies.sh | 0 .evergreen/run-azurekms-fail-test.sh | 0 .evergreen/run-azurekms-test.sh | 0 .evergreen/run-deployed-lambda-aws-tests.sh | 0 .evergreen/run-gcpkms-test.sh | 0 .evergreen/run-perf-tests.sh | 0 .evergreen/scripts/archive-mongodb-logs.sh | 0 .../scripts/bootstrap-mongo-orchestration.sh | 0 .evergreen/scripts/check-import-time.sh | 0 .evergreen/scripts/cleanup.sh | 0 .evergreen/scripts/configure-env.sh | 0 .../scripts/download-and-merge-coverage.sh | 0 .evergreen/scripts/fix-absolute-paths.sh | 0 .evergreen/scripts/init-test-results.sh | 0 .evergreen/scripts/install-dependencies.sh | 0 .evergreen/scripts/make-files-executable.sh | 0 .evergreen/scripts/prepare-resources.sh | 0 .evergreen/scripts/run-atlas-tests.sh | 0 .evergreen/scripts/run-aws-ecs-auth-test.sh | 0 .evergreen/scripts/run-doctests.sh | 0 .../scripts/run-enterprise-auth-tests.sh | 0 .evergreen/scripts/run-gcpkms-fail-test.sh | 0 .evergreen/scripts/run-getdata.sh | 0 .evergreen/scripts/run-load-balancer.sh | 0 .evergreen/scripts/run-mockupdb-tests.sh | 0 .evergreen/scripts/run-mod-wsgi-tests.sh | 0 .evergreen/scripts/run-ocsp-test.sh | 0 .evergreen/scripts/run-perf-tests.sh | 0 .evergreen/scripts/run-tests.sh | 0 .evergreen/scripts/run-with-env.sh | 0 .evergreen/scripts/setup-encryption.sh | 0 .evergreen/scripts/setup-tests.sh | 0 .evergreen/scripts/stop-load-balancer.sh | 0 .evergreen/scripts/teardown-aws.sh | 7 - .evergreen/scripts/teardown-docker.sh | 7 - .evergreen/scripts/upload-coverage-report.sh | 0 .evergreen/scripts/windows-fix.sh | 0 .evergreen/setup-encryption.sh | 0 .evergreen/teardown-encryption.sh | 0 .pre-commit-config.yaml | 12 ++ tools/synchro.sh | 0 44 files changed, 55 insertions(+), 92 deletions(-) mode change 100644 => 100755 .evergreen/combine-coverage.sh mode change 100644 => 100755 .evergreen/hatch.sh mode change 100644 => 100755 .evergreen/install-dependencies.sh mode change 100644 => 100755 .evergreen/run-azurekms-fail-test.sh mode change 100644 => 100755 .evergreen/run-azurekms-test.sh mode change 100644 => 100755 .evergreen/run-deployed-lambda-aws-tests.sh mode change 100644 => 100755 .evergreen/run-gcpkms-test.sh mode change 100644 => 100755 .evergreen/run-perf-tests.sh mode change 100644 => 100755 .evergreen/scripts/archive-mongodb-logs.sh mode change 100644 => 100755 .evergreen/scripts/bootstrap-mongo-orchestration.sh mode change 100644 => 100755 .evergreen/scripts/check-import-time.sh mode change 100644 => 100755 .evergreen/scripts/cleanup.sh mode change 100644 => 100755 .evergreen/scripts/configure-env.sh mode change 100644 => 100755 .evergreen/scripts/download-and-merge-coverage.sh mode change 100644 => 100755 .evergreen/scripts/fix-absolute-paths.sh mode change 100644 => 100755 .evergreen/scripts/init-test-results.sh mode change 100644 => 100755 .evergreen/scripts/install-dependencies.sh mode change 100644 => 100755 .evergreen/scripts/make-files-executable.sh mode change 100644 => 100755 .evergreen/scripts/prepare-resources.sh mode change 100644 => 100755 .evergreen/scripts/run-atlas-tests.sh mode change 100644 => 100755 .evergreen/scripts/run-aws-ecs-auth-test.sh mode change 100644 => 100755 .evergreen/scripts/run-doctests.sh mode change 100644 => 100755 .evergreen/scripts/run-enterprise-auth-tests.sh mode change 100644 => 100755 .evergreen/scripts/run-gcpkms-fail-test.sh mode change 100644 => 100755 .evergreen/scripts/run-getdata.sh mode change 100644 => 100755 .evergreen/scripts/run-load-balancer.sh mode change 100644 => 100755 .evergreen/scripts/run-mockupdb-tests.sh mode change 100644 => 100755 .evergreen/scripts/run-mod-wsgi-tests.sh mode change 100644 => 100755 .evergreen/scripts/run-ocsp-test.sh mode change 100644 => 100755 .evergreen/scripts/run-perf-tests.sh mode change 100644 => 100755 .evergreen/scripts/run-tests.sh mode change 100644 => 100755 .evergreen/scripts/run-with-env.sh mode change 100644 => 100755 .evergreen/scripts/setup-encryption.sh mode change 100644 => 100755 .evergreen/scripts/setup-tests.sh mode change 100644 => 100755 .evergreen/scripts/stop-load-balancer.sh delete mode 100644 .evergreen/scripts/teardown-aws.sh delete mode 100644 .evergreen/scripts/teardown-docker.sh mode change 100644 => 100755 .evergreen/scripts/upload-coverage-report.sh mode change 100644 => 100755 .evergreen/scripts/windows-fix.sh mode change 100644 => 100755 .evergreen/setup-encryption.sh mode change 100644 => 100755 .evergreen/teardown-encryption.sh mode change 100644 => 100755 tools/synchro.sh diff --git a/.evergreen/combine-coverage.sh b/.evergreen/combine-coverage.sh old mode 100644 new mode 100755 diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 7ca3a72b1..ac89270d8 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -37,6 +37,8 @@ functions: # Applies the subitted patch, if any # Deprecated. Should be removed. But still needed for certain agents (ZAP) - command: git.apply_patch + + "setup system": # Make an evergreen expansion file with dynamic values - command: subprocess.exec params: @@ -49,13 +51,19 @@ functions: - command: expansions.update params: file: src/expansion.yml - - "prepare resources": - command: subprocess.exec params: + include_expansions_in_env: ["PROJECT_DIRECTORY", "DRIVERS_TOOLS"] binary: bash args: - src/.evergreen/scripts/prepare-resources.sh + # Run drivers-evergreen-tools system setup + - command: subprocess.exec + params: + include_expansions_in_env: ["PROJECT_DIRECTORY", "DRIVERS_TOOLS"] + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/setup.sh "upload coverage" : - command: ec2.assume_role @@ -511,7 +519,26 @@ functions: - .evergreen/scripts/run-with-env.sh - .evergreen/scripts/cleanup.sh - "teardown": + "teardown system": + - command: subprocess.exec + params: + binary: bash + working_dir: "src" + args: + # Ensure the instance profile is reassigned for aws tests. + - ${DRIVERS_TOOLS}/.evergreen/auth_aws/teardown.sh + - command: subprocess.exec + params: + binary: bash + working_dir: "src" + args: + - ${DRIVERS_TOOLS}/.evergreen/csfle/teardown.sh + - command: subprocess.exec + params: + binary: bash + working_dir: "src" + args: + - ${DRIVERS_TOOLS}/.evergreen/ocsp/teardown.sh - command: subprocess.exec params: binary: bash @@ -519,34 +546,6 @@ functions: args: - ${DRIVERS_TOOLS}/.evergreen/teardown.sh - "fix absolute paths": - - command: subprocess.exec - params: - binary: bash - args: - - src/.evergreen/scripts/fix-absolute-paths.sh - - "windows fix": - - command: subprocess.exec - params: - binary: bash - args: - - src/.evergreen/scripts/windows-fix.sh - - "make files executable": - - command: subprocess.exec - params: - binary: bash - args: - - src/.evergreen/scripts/make-files-executable.sh - - "init test-results": - - command: subprocess.exec - params: - binary: bash - args: - - src/.evergreen/scripts/init-test-results.sh - "install dependencies": - command: subprocess.exec params: @@ -621,21 +620,6 @@ functions: args: - src/.evergreen/scripts/stop-load-balancer.sh - "teardown_docker": - - command: subprocess.exec - params: - binary: bash - args: - - src/.evergreen/scripts/teardown-docker.sh - - "teardown_aws": - - command: subprocess.exec - params: - binary: bash - args: - - src/.evergreen/scripts/run-with-env.sh - - src/.evergreen/scripts/teardown-aws.sh - "teardown atlas": - command: subprocess.exec params: @@ -665,25 +649,19 @@ functions: pre: - func: "fetch source" - - func: "prepare resources" - - func: "windows fix" - - func: "fix absolute paths" - - func: "init test-results" - - func: "make files executable" + - func: "setup system" - func: "install dependencies" - func: "assume ec2 role" post: # Disabled, causing timeouts # - func: "upload working dir" - - func: "teardown" + - func: "teardown system" - func: "upload coverage" - func: "upload mo artifacts" - func: "upload test results" - func: "stop mongo-orchestration" - - func: "teardown_aws" - func: "cleanup" - - func: "teardown_docker" task_groups: - name: serverless_task_group @@ -691,7 +669,7 @@ task_groups: setup_group_timeout_secs: 1800 # 30 minutes setup_group: - func: "fetch source" - - func: "prepare resources" + - func: "setup system" - command: subprocess.exec params: binary: bash @@ -714,9 +692,7 @@ task_groups: setup_group_timeout_secs: 1800 # 30 minutes setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - command: subprocess.exec params: binary: bash @@ -735,9 +711,7 @@ task_groups: - name: testazurekms_task_group setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - command: subprocess.exec params: binary: bash @@ -761,9 +735,7 @@ task_groups: - name: testazureoidc_task_group setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - command: subprocess.exec params: binary: bash @@ -785,9 +757,7 @@ task_groups: - name: testgcpoidc_task_group setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - command: subprocess.exec params: binary: bash @@ -809,9 +779,7 @@ task_groups: - name: testk8soidc_task_group setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - command: ec2.assume_role params: role_arn: ${aws_test_secrets_role} @@ -835,9 +803,7 @@ task_groups: - name: testoidc_task_group setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - func: "assume ec2 role" - command: subprocess.exec params: @@ -859,7 +825,7 @@ task_groups: - name: test_aws_lambda_task_group setup_group: - func: fetch source - - func: prepare resources + - func: setup system - func: setup atlas teardown_task: - func: teardown atlas @@ -871,9 +837,7 @@ task_groups: - name: test_atlas_task_group_search_indexes setup_group: - func: fetch source - - func: prepare resources - - func: fix absolute paths - - func: make files executable + - func: setup system - func: setup atlas teardown_task: - func: teardown atlas @@ -1584,7 +1548,7 @@ tasks: - name: testazurekms-fail-task commands: - func: fetch source - - func: make files executable + - func: setup system - func: "bootstrap mongo-orchestration" vars: VERSION: "latest" @@ -1640,6 +1604,7 @@ tasks: params: binary: bash working_dir: src + include_expansions_in_env: ["PYTHON_BINARY"] args: - .evergreen/scripts/check-import-time.sh - ${revision} diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh old mode 100644 new mode 100755 diff --git a/.evergreen/install-dependencies.sh b/.evergreen/install-dependencies.sh old mode 100644 new mode 100755 diff --git a/.evergreen/run-azurekms-fail-test.sh b/.evergreen/run-azurekms-fail-test.sh old mode 100644 new mode 100755 diff --git a/.evergreen/run-azurekms-test.sh b/.evergreen/run-azurekms-test.sh old mode 100644 new mode 100755 diff --git a/.evergreen/run-deployed-lambda-aws-tests.sh b/.evergreen/run-deployed-lambda-aws-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/run-gcpkms-test.sh b/.evergreen/run-gcpkms-test.sh old mode 100644 new mode 100755 diff --git a/.evergreen/run-perf-tests.sh b/.evergreen/run-perf-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/archive-mongodb-logs.sh b/.evergreen/scripts/archive-mongodb-logs.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/bootstrap-mongo-orchestration.sh b/.evergreen/scripts/bootstrap-mongo-orchestration.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/check-import-time.sh b/.evergreen/scripts/check-import-time.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/cleanup.sh b/.evergreen/scripts/cleanup.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/configure-env.sh b/.evergreen/scripts/configure-env.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/download-and-merge-coverage.sh b/.evergreen/scripts/download-and-merge-coverage.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/fix-absolute-paths.sh b/.evergreen/scripts/fix-absolute-paths.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/init-test-results.sh b/.evergreen/scripts/init-test-results.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/install-dependencies.sh b/.evergreen/scripts/install-dependencies.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/make-files-executable.sh b/.evergreen/scripts/make-files-executable.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/prepare-resources.sh b/.evergreen/scripts/prepare-resources.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-atlas-tests.sh b/.evergreen/scripts/run-atlas-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-aws-ecs-auth-test.sh b/.evergreen/scripts/run-aws-ecs-auth-test.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-doctests.sh b/.evergreen/scripts/run-doctests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-enterprise-auth-tests.sh b/.evergreen/scripts/run-enterprise-auth-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-gcpkms-fail-test.sh b/.evergreen/scripts/run-gcpkms-fail-test.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-getdata.sh b/.evergreen/scripts/run-getdata.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-load-balancer.sh b/.evergreen/scripts/run-load-balancer.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-mockupdb-tests.sh b/.evergreen/scripts/run-mockupdb-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-mod-wsgi-tests.sh b/.evergreen/scripts/run-mod-wsgi-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-ocsp-test.sh b/.evergreen/scripts/run-ocsp-test.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-perf-tests.sh b/.evergreen/scripts/run-perf-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-tests.sh b/.evergreen/scripts/run-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/run-with-env.sh b/.evergreen/scripts/run-with-env.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/setup-encryption.sh b/.evergreen/scripts/setup-encryption.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/setup-tests.sh b/.evergreen/scripts/setup-tests.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/stop-load-balancer.sh b/.evergreen/scripts/stop-load-balancer.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/teardown-aws.sh b/.evergreen/scripts/teardown-aws.sh deleted file mode 100644 index 634d1e572..000000000 --- a/.evergreen/scripts/teardown-aws.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -cd "${DRIVERS_TOOLS}/.evergreen/auth_aws" || exit -if [ -f "./aws_e2e_setup.json" ]; then - . ./activate-authawsvenv.sh - python ./lib/aws_assign_instance_profile.py -fi diff --git a/.evergreen/scripts/teardown-docker.sh b/.evergreen/scripts/teardown-docker.sh deleted file mode 100644 index 733779d05..000000000 --- a/.evergreen/scripts/teardown-docker.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -# Remove all Docker images -DOCKER=$(command -v docker) || true -if [ -n "$DOCKER" ]; then - docker rmi -f "$(docker images -a -q)" &> /dev/null || true -fi diff --git a/.evergreen/scripts/upload-coverage-report.sh b/.evergreen/scripts/upload-coverage-report.sh old mode 100644 new mode 100755 diff --git a/.evergreen/scripts/windows-fix.sh b/.evergreen/scripts/windows-fix.sh old mode 100644 new mode 100755 diff --git a/.evergreen/setup-encryption.sh b/.evergreen/setup-encryption.sh old mode 100644 new mode 100755 diff --git a/.evergreen/teardown-encryption.sh b/.evergreen/teardown-encryption.sh old mode 100644 new mode 100755 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e2b497e5..4f6759bc5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -102,3 +102,15 @@ repos: # - test/versioned-api/crud-api-version-1-strict.json:514: nin ==> inn, min, bin, nine # - test/test_client.py:188: te ==> the, be, we, to args: ["-L", "fle,fo,infinit,isnt,nin,te,aks"] + +- repo: local + hooks: + - id: executable-shell + name: executable-shell + entry: chmod +x + language: system + types: [shell] + exclude: | + (?x)( + .evergreen/retry-with-backoff.sh + ) diff --git a/tools/synchro.sh b/tools/synchro.sh old mode 100644 new mode 100755 From 0f61ebb1150266a8cc9df70b87c4175b5f23aead Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 2 Dec 2024 12:35:31 -0500 Subject: [PATCH 8/9] PYTHON-4995 - Skip TestNoSessionsSupport tests on crypt_shared (#2022) --- test/asynchronous/test_encryption.py | 2 +- test/test_encryption.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 048db2d50..21cd5e266 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -3069,11 +3069,11 @@ def start_mongocryptd(port) -> None: _spawn_daemon(args) +@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") class TestNoSessionsSupport(AsyncEncryptionIntegrationTest): mongocryptd_client: AsyncMongoClient MONGOCRYPTD_PORT = 27020 - @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") async def asyncSetUp(self) -> None: await super().asyncSetUp() start_mongocryptd(self.MONGOCRYPTD_PORT) diff --git a/test/test_encryption.py b/test/test_encryption.py index cb8bcb74d..18e21fe6a 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -3051,11 +3051,11 @@ def start_mongocryptd(port) -> None: _spawn_daemon(args) +@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") class TestNoSessionsSupport(EncryptionIntegrationTest): mongocryptd_client: MongoClient MONGOCRYPTD_PORT = 27020 - @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") def setUp(self) -> None: super().setUp() start_mongocryptd(self.MONGOCRYPTD_PORT) From a9e61f6bed71dbf9a26d46d18d9905a6a0ccdd16 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 2 Dec 2024 10:08:52 -0800 Subject: [PATCH 9/9] PYTHON-4292 Improve TLS read performance (#2020) --- pymongo/network_layer.py | 89 +++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 52 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 6ab6db2f7..beffba6d1 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -28,7 +28,7 @@ from typing import ( Union, ) -from pymongo import _csot, ssl_support +from pymongo import ssl_support from pymongo._asyncio_task import create_task from pymongo.errors import _OperationCancelled from pymongo.socket_checker import _errno_from_exception @@ -316,62 +316,47 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo return mv -# Sync version: -def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - - def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 - while bytes_read < length: - try: - wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: + # To support cancelling a network read, we shorten the socket timeout and + # check for the cancellation signal after each timeout. Alternatively we + # could close the socket but that does not reliably cancel recv() calls + # on all OSes. + orig_timeout = conn.conn.gettimeout() + try: + while bytes_read < length: + if deadline is not None: + # CSOT: Update timeout. When the timeout has expired perform one + # final non-blocking recv. This helps avoid spurious timeouts when + # the response is actually already buffered on the client. + short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT) + else: + short_timeout = _POLL_TIMEOUT + conn.set_conn_timeout(short_timeout) + try: + chunk_length = conn.conn.recv_into(mv[bytes_read:]) + except BLOCKING_IO_ERRORS: + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") from None + # We reached the true deadline. + raise socket.timeout("timed out") from None + except socket.timeout: + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") from None continue - raise - if chunk_length == 0: - raise OSError("connection closed") + except OSError as exc: + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") from None + if _errno_from_exception(exc) == errno.EINTR: + continue + raise + if chunk_length == 0: + raise OSError("connection closed") - bytes_read += chunk_length + bytes_read += chunk_length + finally: + conn.set_conn_timeout(orig_timeout) return mv