Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
c14943ab72
0
.evergreen/combine-coverage.sh
Normal file → Executable file
0
.evergreen/combine-coverage.sh
Normal file → Executable file
File diff suppressed because it is too large
Load Diff
2
.evergreen/hatch.sh
Normal file → Executable file
2
.evergreen/hatch.sh
Normal file → Executable file
@ -29,7 +29,7 @@ else # Set up virtualenv before installing hatch
|
||||
# Ensure hatch does not write to user or global locations.
|
||||
touch hatch_config.toml
|
||||
HATCH_CONFIG=$(pwd)/hatch_config.toml
|
||||
if [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin
|
||||
if [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
|
||||
HATCH_CONFIG=$(cygpath -m "$HATCH_CONFIG")
|
||||
fi
|
||||
export HATCH_CONFIG
|
||||
|
||||
0
.evergreen/install-dependencies.sh
Normal file → Executable file
0
.evergreen/install-dependencies.sh
Normal file → Executable file
0
.evergreen/run-azurekms-fail-test.sh
Normal file → Executable file
0
.evergreen/run-azurekms-fail-test.sh
Normal file → Executable file
0
.evergreen/run-azurekms-test.sh
Normal file → Executable file
0
.evergreen/run-azurekms-test.sh
Normal file → Executable file
0
.evergreen/run-deployed-lambda-aws-tests.sh
Normal file → Executable file
0
.evergreen/run-deployed-lambda-aws-tests.sh
Normal file → Executable file
0
.evergreen/run-gcpkms-test.sh
Normal file → Executable file
0
.evergreen/run-gcpkms-test.sh
Normal file → Executable file
0
.evergreen/run-perf-tests.sh
Normal file → Executable file
0
.evergreen/run-perf-tests.sh
Normal file → Executable file
@ -38,6 +38,7 @@ export PIP_PREFER_BINARY=1 # Prefer binary dists by default
|
||||
|
||||
set +x
|
||||
python -c "import sys; sys.exit(sys.prefix == sys.base_prefix)" || (echo "Not inside a virtual env!"; exit 1)
|
||||
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
|
||||
|
||||
# Try to source local Drivers Secrets
|
||||
if [ -f ./secrets-export.sh ]; then
|
||||
@ -47,19 +48,24 @@ else
|
||||
echo "Not sourcing secrets"
|
||||
fi
|
||||
|
||||
# Ensure C extensions have compiled.
|
||||
if [ -z "${NO_EXT:-}" ] && [ "$PYTHON_IMPL" = "CPython" ]; then
|
||||
python tools/fail_if_no_c.py
|
||||
fi
|
||||
|
||||
if [ "$AUTH" != "noauth" ]; then
|
||||
if [ ! -z "$TEST_DATA_LAKE" ]; then
|
||||
if [ -n "$TEST_DATA_LAKE" ]; then
|
||||
export DB_USER="mhuser"
|
||||
export DB_PASSWORD="pencil"
|
||||
elif [ ! -z "$TEST_SERVERLESS" ]; then
|
||||
source ${DRIVERS_TOOLS}/.evergreen/serverless/secrets-export.sh
|
||||
elif [ -n "$TEST_SERVERLESS" ]; then
|
||||
source "${DRIVERS_TOOLS}"/.evergreen/serverless/secrets-export.sh
|
||||
export DB_USER=$SERVERLESS_ATLAS_USER
|
||||
export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD
|
||||
export MONGODB_URI="$SERVERLESS_URI"
|
||||
echo "MONGODB_URI=$MONGODB_URI"
|
||||
export SINGLE_MONGOS_LB_URI=$MONGODB_URI
|
||||
export MULTI_MONGOS_LB_URI=$MONGODB_URI
|
||||
elif [ ! -z "$TEST_AUTH_OIDC" ]; then
|
||||
elif [ -n "$TEST_AUTH_OIDC" ]; then
|
||||
export DB_USER=$OIDC_ADMIN_USER
|
||||
export DB_PASSWORD=$OIDC_ADMIN_PWD
|
||||
export DB_IP="$MONGODB_URI"
|
||||
@ -240,7 +246,6 @@ python -c 'import sys; print(sys.version)'
|
||||
|
||||
# Run the tests with coverage if requested and coverage is installed.
|
||||
# Only cover CPython. PyPy reports suspiciously low coverage.
|
||||
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
|
||||
if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then
|
||||
# Keep in sync with combine-coverage.sh.
|
||||
# coverage >=5 is needed for relative_files=true.
|
||||
|
||||
8
.evergreen/scripts/archive-mongodb-logs.sh
Executable file
8
.evergreen/scripts/archive-mongodb-logs.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
mkdir out_dir
|
||||
# shellcheck disable=SC2156
|
||||
find "$MONGO_ORCHESTRATION_HOME" -name \*.log -exec sh -c 'x="{}"; mv $x $PWD/out_dir/$(basename $(dirname $x))_$(basename $x)' \;
|
||||
tar zcvf mongodb-logs.tar.gz -C out_dir/ .
|
||||
rm -rf out_dir
|
||||
46
.evergreen/scripts/bootstrap-mongo-orchestration.sh
Executable file
46
.evergreen/scripts/bootstrap-mongo-orchestration.sh
Executable file
@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
|
||||
# Enable core dumps if enabled on the machine
|
||||
# Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml
|
||||
if [ -f /proc/self/coredump_filter ]; then
|
||||
# Set the shell process (and its children processes) to dump ELF headers (bit 4),
|
||||
# anonymous shared mappings (bit 1), and anonymous private mappings (bit 0).
|
||||
echo 0x13 >/proc/self/coredump_filter
|
||||
|
||||
if [ -f /sbin/sysctl ]; then
|
||||
# Check that the core pattern is set explicitly on our distro image instead
|
||||
# of being the OS's default value. This ensures that coredump names are consistent
|
||||
# across distros and can be picked up by Evergreen.
|
||||
core_pattern=$(/sbin/sysctl -n "kernel.core_pattern")
|
||||
if [ "$core_pattern" = "dump_%e.%p.core" ]; then
|
||||
echo "Enabling coredumps"
|
||||
ulimit -c unlimited
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$(uname -s)" = "Darwin" ]; then
|
||||
core_pattern_mac=$(/usr/sbin/sysctl -n "kern.corefile")
|
||||
if [ "$core_pattern_mac" = "dump_%N.%P.core" ]; then
|
||||
echo "Enabling coredumps"
|
||||
ulimit -c unlimited
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "${skip_crypt_shared}" ]; then
|
||||
export SKIP_CRYPT_SHARED=1
|
||||
fi
|
||||
|
||||
MONGODB_VERSION=${VERSION} \
|
||||
TOPOLOGY=${TOPOLOGY} \
|
||||
AUTH=${AUTH:-noauth} \
|
||||
SSL=${SSL:-nossl} \
|
||||
STORAGE_ENGINE=${STORAGE_ENGINE:-} \
|
||||
DISABLE_TEST_COMMANDS=${DISABLE_TEST_COMMANDS:-} \
|
||||
ORCHESTRATION_FILE=${ORCHESTRATION_FILE:-} \
|
||||
REQUIRE_API_VERSION=${REQUIRE_API_VERSION:-} \
|
||||
LOAD_BALANCER=${LOAD_BALANCER:-} \
|
||||
bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh
|
||||
# run-orchestration generates expansion file with the MONGODB_URI for the cluster
|
||||
7
.evergreen/scripts/check-import-time.sh
Executable file
7
.evergreen/scripts/check-import-time.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
. .evergreen/scripts/env.sh
|
||||
set -x
|
||||
export BASE_SHA="$1"
|
||||
export HEAD_SHA="$2"
|
||||
bash .evergreen/run-import-time-test.sh
|
||||
7
.evergreen/scripts/cleanup.sh
Executable file
7
.evergreen/scripts/cleanup.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ -f "$DRIVERS_TOOLS"/.evergreen/csfle/secrets-export.sh ]; then
|
||||
. .evergreen/hatch.sh encryption:teardown
|
||||
fi
|
||||
rm -rf "${DRIVERS_TOOLS}" || true
|
||||
rm -f ./secrets-export.sh || true
|
||||
19
.evergreen/scripts/configure-env.sh
Normal file → Executable file
19
.evergreen/scripts/configure-env.sh
Normal file → Executable file
@ -1,4 +1,4 @@
|
||||
#!/bin/bash -ex
|
||||
#!/bin/bash -eux
|
||||
|
||||
# Get the current unique version of this checkout
|
||||
# shellcheck disable=SC2154
|
||||
@ -29,7 +29,7 @@ fi
|
||||
export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration"
|
||||
export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin"
|
||||
|
||||
cat <<EOT > $SCRIPT_DIR/env.sh
|
||||
cat <<EOT > "$SCRIPT_DIR"/env.sh
|
||||
set -o errexit
|
||||
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
|
||||
export CURRENT_VERSION="$CURRENT_VERSION"
|
||||
@ -38,6 +38,21 @@ export DRIVERS_TOOLS="$DRIVERS_TOOLS"
|
||||
export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
|
||||
export MONGODB_BINARIES="$MONGODB_BINARIES"
|
||||
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
|
||||
export SETDEFAULTENCODING="${SETDEFAULTENCODING:-}"
|
||||
export SKIP_CSOT_TESTS="${SKIP_CSOT_TESTS:-}"
|
||||
export MONGODB_STARTED="${MONGODB_STARTED:-}"
|
||||
export DISABLE_TEST_COMMANDS="${DISABLE_TEST_COMMANDS:-}"
|
||||
export GREEN_FRAMEWORK="${GREEN_FRAMEWORK:-}"
|
||||
export NO_EXT="${NO_EXT:-}"
|
||||
export COVERAGE="${COVERAGE:-}"
|
||||
export COMPRESSORS="${COMPRESSORS:-}"
|
||||
export MONGODB_API_VERSION="${MONGODB_API_VERSION:-}"
|
||||
export SKIP_HATCH="${SKIP_HATCH:-}"
|
||||
export skip_crypt_shared="${skip_crypt_shared:-}"
|
||||
export STORAGE_ENGINE="${STORAGE_ENGINE:-}"
|
||||
export REQUIRE_API_VERSION="${REQUIRE_API_VERSION:-}"
|
||||
export skip_web_identity_auth_test="${skip_web_identity_auth_test:-}"
|
||||
export skip_ECS_auth_test="${skip_ECS_auth_test:-}"
|
||||
|
||||
export TMPDIR="$MONGO_ORCHESTRATION_HOME/db"
|
||||
export PATH="$MONGODB_BINARIES:$PATH"
|
||||
|
||||
4
.evergreen/scripts/download-and-merge-coverage.sh
Executable file
4
.evergreen/scripts/download-and-merge-coverage.sh
Executable file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Download all the task coverage files.
|
||||
aws s3 cp --recursive s3://"$1"/coverage/"$2"/"$3"/coverage/ coverage/
|
||||
8
.evergreen/scripts/fix-absolute-paths.sh
Executable file
8
.evergreen/scripts/fix-absolute-paths.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
set +x
|
||||
. src/.evergreen/scripts/env.sh
|
||||
# shellcheck disable=SC2044
|
||||
for filename in $(find $DRIVERS_TOOLS -name \*.json); do
|
||||
perl -p -i -e "s|ABSOLUTE_PATH_REPLACEMENT_TOKEN|$DRIVERS_TOOLS|g" $filename
|
||||
done
|
||||
5
.evergreen/scripts/init-test-results.sh
Executable file
5
.evergreen/scripts/init-test-results.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
set +x
|
||||
. src/.evergreen/scripts/env.sh
|
||||
echo '{"results": [{ "status": "FAIL", "test_file": "Build", "log_raw": "No test-results.json found was created" } ]}' >$PROJECT_DIRECTORY/test-results.json
|
||||
6
.evergreen/scripts/install-dependencies.sh
Executable file
6
.evergreen/scripts/install-dependencies.sh
Executable file
@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
file="$PROJECT_DIRECTORY/.evergreen/install-dependencies.sh"
|
||||
# Don't use ${file} syntax here because evergreen treats it as an empty expansion.
|
||||
[ -f "$file" ] && bash "$file" || echo "$file not available, skipping"
|
||||
8
.evergreen/scripts/make-files-executable.sh
Executable file
8
.evergreen/scripts/make-files-executable.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
set +x
|
||||
. src/.evergreen/scripts/env.sh
|
||||
# shellcheck disable=SC2044
|
||||
for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do
|
||||
chmod +x "$i"
|
||||
done
|
||||
12
.evergreen/scripts/prepare-resources.sh
Executable file
12
.evergreen/scripts/prepare-resources.sh
Executable file
@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
. src/.evergreen/scripts/env.sh
|
||||
set -o xtrace
|
||||
rm -rf $DRIVERS_TOOLS
|
||||
if [ "$PROJECT" = "drivers-tools" ]; then
|
||||
# If this was a patch build, doing a fresh clone would not actually test the patch
|
||||
cp -R $PROJECT_DIRECTORY/ $DRIVERS_TOOLS
|
||||
else
|
||||
git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git $DRIVERS_TOOLS
|
||||
fi
|
||||
echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" >$MONGO_ORCHESTRATION_HOME/orchestration.config
|
||||
7
.evergreen/scripts/run-atlas-tests.sh
Executable file
7
.evergreen/scripts/run-atlas-tests.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Disable xtrace for security reasons (just in case it was accidentally set).
|
||||
set +x
|
||||
set -o errexit
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/atlas_connect
|
||||
TEST_ATLAS=1 bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg
|
||||
15
.evergreen/scripts/run-aws-ecs-auth-test.sh
Executable file
15
.evergreen/scripts/run-aws-ecs-auth-test.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
# shellcheck disable=SC2154
|
||||
if [ "${skip_ECS_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the ECS auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
set -ex
|
||||
cd "$DRIVERS_TOOLS"/.evergreen/auth_aws
|
||||
. ./activate-authawsvenv.sh
|
||||
. aws_setup.sh ecs
|
||||
export MONGODB_BINARIES="$MONGODB_BINARIES"
|
||||
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
|
||||
python aws_tester.py ecs
|
||||
cd -
|
||||
4
.evergreen/scripts/run-doctests.sh
Executable file
4
.evergreen/scripts/run-doctests.sh
Executable file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
PYTHON_BINARY=${PYTHON_BINARY} bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh doctest:test
|
||||
6
.evergreen/scripts/run-enterprise-auth-tests.sh
Executable file
6
.evergreen/scripts/run-enterprise-auth-tests.sh
Executable file
@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Disable xtrace for security reasons (just in case it was accidentally set).
|
||||
set +x
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/enterprise_auth
|
||||
TEST_ENTERPRISE_AUTH=1 AUTH=auth bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg
|
||||
7
.evergreen/scripts/run-gcpkms-fail-test.sh
Executable file
7
.evergreen/scripts/run-gcpkms-fail-test.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
. .evergreen/scripts/env.sh
|
||||
export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3
|
||||
export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz
|
||||
SKIP_SERVERS=1 bash ./.evergreen/setup-encryption.sh
|
||||
SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/hatch.sh test:test-eg
|
||||
22
.evergreen/scripts/run-getdata.sh
Executable file
22
.evergreen/scripts/run-getdata.sh
Executable file
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
. ${DRIVERS_TOOLS}/.evergreen/download-mongodb.sh || true
|
||||
get_distro || true
|
||||
echo $DISTRO
|
||||
echo $MARCH
|
||||
echo $OS
|
||||
uname -a || true
|
||||
ls /etc/*release* || true
|
||||
cc --version || true
|
||||
gcc --version || true
|
||||
clang --version || true
|
||||
gcov --version || true
|
||||
lcov --version || true
|
||||
llvm-cov --version || true
|
||||
echo $PATH
|
||||
ls -la /usr/local/Cellar/llvm/*/bin/ || true
|
||||
ls -la /usr/local/Cellar/ || true
|
||||
scan-build --version || true
|
||||
genhtml --version || true
|
||||
valgrind --version || true
|
||||
3
.evergreen/scripts/run-load-balancer.sh
Executable file
3
.evergreen/scripts/run-load-balancer.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
MONGODB_URI=${MONGODB_URI} bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh start
|
||||
5
.evergreen/scripts/run-mockupdb-tests.sh
Executable file
5
.evergreen/scripts/run-mockupdb-tests.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
export PYTHON_BINARY=${PYTHON_BINARY}
|
||||
bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-mockupdb
|
||||
2
.evergreen/run-mod-wsgi-tests.sh → .evergreen/scripts/run-mod-wsgi-tests.sh
Normal file → Executable file
2
.evergreen/run-mod-wsgi-tests.sh → .evergreen/scripts/run-mod-wsgi-tests.sh
Normal file → Executable file
@ -28,7 +28,7 @@ export MOD_WSGI_SO=/opt/python/mod_wsgi/python_version/$PYTHON_VERSION/mod_wsgi_
|
||||
export PYTHONHOME=/opt/python/$PYTHON_VERSION
|
||||
# If MOD_WSGI_EMBEDDED is set use the default embedded mode behavior instead
|
||||
# of daemon mode (WSGIDaemonProcess).
|
||||
if [ -n "$MOD_WSGI_EMBEDDED" ]; then
|
||||
if [ -n "${MOD_WSGI_EMBEDDED:-}" ]; then
|
||||
export MOD_WSGI_CONF=mod_wsgi_test_embedded.conf
|
||||
else
|
||||
export MOD_WSGI_CONF=mod_wsgi_test.conf
|
||||
@ -13,10 +13,16 @@ set -o errexit # Exit the script with error if any of the commands fail
|
||||
# mechanism.
|
||||
# PYTHON_BINARY The Python version to use.
|
||||
|
||||
echo "Running MONGODB-AWS authentication tests"
|
||||
# shellcheck disable=SC2154
|
||||
if [ "${skip_EC2_auth_test:-}" = "true" ] && { [ "$1" = "ec2" ] || [ "$1" = "web-identity" ]; }; then
|
||||
echo "This platform does not support the EC2 auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Running MONGODB-AWS authentication tests for $1"
|
||||
|
||||
# Handle credentials and environment setup.
|
||||
. $DRIVERS_TOOLS/.evergreen/auth_aws/aws_setup.sh $1
|
||||
. "$DRIVERS_TOOLS"/.evergreen/auth_aws/aws_setup.sh "$1"
|
||||
|
||||
# show test output
|
||||
set -x
|
||||
8
.evergreen/scripts/run-ocsp-test.sh
Executable file
8
.evergreen/scripts/run-ocsp-test.sh
Executable file
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
TEST_OCSP=1 \
|
||||
PYTHON_BINARY="${PYTHON_BINARY}" \
|
||||
CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem" \
|
||||
OCSP_TLS_SHOULD_SUCCEED="${OCSP_TLS_SHOULD_SUCCEED}" \
|
||||
bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/ocsp/teardown.sh
|
||||
4
.evergreen/scripts/run-perf-tests.sh
Executable file
4
.evergreen/scripts/run-perf-tests.sh
Executable file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
PROJECT_DIRECTORY=${PROJECT_DIRECTORY}
|
||||
bash "${PROJECT_DIRECTORY}"/.evergreen/run-perf-tests.sh
|
||||
55
.evergreen/scripts/run-tests.sh
Executable file
55
.evergreen/scripts/run-tests.sh
Executable file
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Disable xtrace
|
||||
set +x
|
||||
if [ -n "${MONGODB_STARTED}" ]; then
|
||||
export PYMONGO_MUST_CONNECT=true
|
||||
fi
|
||||
if [ -n "${DISABLE_TEST_COMMANDS}" ]; then
|
||||
export PYMONGO_DISABLE_TEST_COMMANDS=1
|
||||
fi
|
||||
if [ -n "${test_encryption}" ]; then
|
||||
# Disable xtrace (just in case it was accidentally set).
|
||||
set +x
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/csfle/await-servers.sh
|
||||
export TEST_ENCRYPTION=1
|
||||
if [ -n "${test_encryption_pyopenssl}" ]; then
|
||||
export TEST_ENCRYPTION_PYOPENSSL=1
|
||||
fi
|
||||
fi
|
||||
if [ -n "${test_crypt_shared}" ]; then
|
||||
export TEST_CRYPT_SHARED=1
|
||||
export CRYPT_SHARED_LIB_PATH=${CRYPT_SHARED_LIB_PATH}
|
||||
fi
|
||||
if [ -n "${test_pyopenssl}" ]; then
|
||||
export TEST_PYOPENSSL=1
|
||||
fi
|
||||
if [ -n "${SETDEFAULTENCODING}" ]; then
|
||||
export SETDEFAULTENCODING="${SETDEFAULTENCODING}"
|
||||
fi
|
||||
if [ -n "${test_loadbalancer}" ]; then
|
||||
export TEST_LOADBALANCER=1
|
||||
export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}"
|
||||
export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}"
|
||||
fi
|
||||
if [ -n "${test_serverless}" ]; then
|
||||
export TEST_SERVERLESS=1
|
||||
fi
|
||||
if [ -n "${TEST_INDEX_MANAGEMENT:-}" ]; then
|
||||
export TEST_INDEX_MANAGEMENT=1
|
||||
fi
|
||||
if [ -n "${SKIP_CSOT_TESTS}" ]; then
|
||||
export SKIP_CSOT_TESTS=1
|
||||
fi
|
||||
GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \
|
||||
PYTHON_BINARY=${PYTHON_BINARY} \
|
||||
NO_EXT=${NO_EXT} \
|
||||
COVERAGE=${COVERAGE} \
|
||||
COMPRESSORS=${COMPRESSORS} \
|
||||
AUTH=${AUTH} \
|
||||
SSL=${SSL} \
|
||||
TEST_DATA_LAKE=${TEST_DATA_LAKE:-} \
|
||||
TEST_SUITES=${TEST_SUITES:-} \
|
||||
MONGODB_API_VERSION=${MONGODB_API_VERSION} \
|
||||
SKIP_HATCH=${SKIP_HATCH} \
|
||||
bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg
|
||||
21
.evergreen/scripts/run-with-env.sh
Executable file
21
.evergreen/scripts/run-with-env.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/bin/bash -eu
|
||||
|
||||
# Example use: bash run-with-env.sh run-tests.sh {args...}
|
||||
|
||||
# Parameter expansion to get just the current directory's name
|
||||
if [ "${PWD##*/}" == "src" ]; then
|
||||
. .evergreen/scripts/env.sh
|
||||
if [ -f ".evergreen/scripts/test-env.sh" ]; then
|
||||
. .evergreen/scripts/test-env.sh
|
||||
fi
|
||||
else
|
||||
. src/.evergreen/scripts/env.sh
|
||||
if [ -f "src/.evergreen/scripts/test-env.sh" ]; then
|
||||
. src/.evergreen/scripts/test-env.sh
|
||||
fi
|
||||
fi
|
||||
|
||||
set -eu
|
||||
|
||||
# shellcheck source=/dev/null
|
||||
. "$@"
|
||||
5
.evergreen/scripts/setup-encryption.sh
Executable file
5
.evergreen/scripts/setup-encryption.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ -n "${test_encryption}" ]; then
|
||||
./.evergreen/hatch.sh encryption:setup
|
||||
fi
|
||||
27
.evergreen/scripts/setup-tests.sh
Executable file
27
.evergreen/scripts/setup-tests.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash -eux
|
||||
|
||||
PROJECT_DIRECTORY="$(pwd)"
|
||||
SCRIPT_DIR="$PROJECT_DIRECTORY/.evergreen/scripts"
|
||||
|
||||
if [ -f "$SCRIPT_DIR/test-env.sh" ]; then
|
||||
echo "Reading $SCRIPT_DIR/test-env.sh file"
|
||||
. "$SCRIPT_DIR/test-env.sh"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
cat <<EOT > "$SCRIPT_DIR"/test-env.sh
|
||||
export test_encryption="${test_encryption:-}"
|
||||
export test_encryption_pyopenssl="${test_encryption_pyopenssl:-}"
|
||||
export test_crypt_shared="${test_crypt_shared:-}"
|
||||
export test_pyopenssl="${test_pyopenssl:-}"
|
||||
export test_loadbalancer="${test_loadbalancer:-}"
|
||||
export test_serverless="${test_serverless:-}"
|
||||
export TEST_INDEX_MANAGEMENT="${TEST_INDEX_MANAGEMENT:-}"
|
||||
export TEST_DATA_LAKE="${TEST_DATA_LAKE:-}"
|
||||
export ORCHESTRATION_FILE="${ORCHESTRATION_FILE:-}"
|
||||
export AUTH="${AUTH:-noauth}"
|
||||
export SSL="${SSL:-nossl}"
|
||||
export PYTHON_BINARY="${PYTHON_BINARY:-}"
|
||||
EOT
|
||||
|
||||
chmod +x "$SCRIPT_DIR"/test-env.sh
|
||||
5
.evergreen/scripts/stop-load-balancer.sh
Executable file
5
.evergreen/scripts/stop-load-balancer.sh
Executable file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd "${DRIVERS_TOOLS}"/.evergreen || exit
|
||||
DRIVERS_TOOLS=${DRIVERS_TOOLS}
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh stop
|
||||
3
.evergreen/scripts/upload-coverage-report.sh
Executable file
3
.evergreen/scripts/upload-coverage-report.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
aws s3 cp htmlcov/ s3://"$1"/coverage/"$2"/"$3"/htmlcov/ --recursive --acl public-read --region us-east-1
|
||||
11
.evergreen/scripts/windows-fix.sh
Executable file
11
.evergreen/scripts/windows-fix.sh
Executable file
@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
set +x
|
||||
. src/.evergreen/scripts/env.sh
|
||||
# shellcheck disable=SC2044
|
||||
for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do
|
||||
< "$i" tr -d '\r' >"$i".new
|
||||
mv "$i".new "$i"
|
||||
done
|
||||
# Copy client certificate because symlinks do not work on Windows.
|
||||
cp "$DRIVERS_TOOLS"/.evergreen/x509gen/client.pem "$MONGO_ORCHESTRATION_HOME"/lib/client.pem
|
||||
7
.evergreen/setup-encryption.sh
Normal file → Executable file
7
.evergreen/setup-encryption.sh
Normal file → Executable file
@ -52,6 +52,9 @@ ls -la libmongocrypt
|
||||
ls -la libmongocrypt/nocrypto
|
||||
|
||||
if [ -z "${SKIP_SERVERS:-}" ]; then
|
||||
bash ${DRIVERS_TOOLS}/.evergreen/csfle/setup-secrets.sh
|
||||
bash ${DRIVERS_TOOLS}/.evergreen/csfle/start-servers.sh
|
||||
PYTHON_BINARY_OLD=${PYTHON_BINARY}
|
||||
export PYTHON_BINARY=""
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/csfle/setup-secrets.sh
|
||||
export PYTHON_BINARY=$PYTHON_BINARY_OLD
|
||||
bash "${DRIVERS_TOOLS}"/.evergreen/csfle/start-servers.sh
|
||||
fi
|
||||
|
||||
0
.evergreen/teardown-encryption.sh
Normal file → Executable file
0
.evergreen/teardown-encryption.sh
Normal file → Executable file
@ -17,7 +17,7 @@ find_python3() {
|
||||
elif [ -d "/Library/Frameworks/Python.Framework/Versions/3.9" ]; then
|
||||
PYTHON="/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3"
|
||||
fi
|
||||
elif [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin
|
||||
elif [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
|
||||
PYTHON="C:/python/Python39/python.exe"
|
||||
else
|
||||
# Prefer our own toolchain, fall back to mongodb toolchain if it has Python 3.9+.
|
||||
@ -56,7 +56,7 @@ createvirtualenv () {
|
||||
# Workaround for bug in older versions of virtualenv.
|
||||
$VIRTUALENV $VENVPATH 2>/dev/null || $VIRTUALENV $VENVPATH
|
||||
fi
|
||||
if [ "Windows_NT" = "$OS" ]; then
|
||||
if [ "Windows_NT" = "${OS:-}" ]; then
|
||||
# Workaround https://bugs.python.org/issue32451:
|
||||
# mongovenv/Scripts/activate: line 3: $'\r': command not found
|
||||
dos2unix $VENVPATH/Scripts/activate || true
|
||||
@ -78,6 +78,7 @@ testinstall () {
|
||||
PYTHON=$1
|
||||
RELEASE=$2
|
||||
NO_VIRTUALENV=$3
|
||||
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
|
||||
|
||||
if [ -z "$NO_VIRTUALENV" ]; then
|
||||
createvirtualenv $PYTHON venvtestinstall
|
||||
@ -86,7 +87,11 @@ testinstall () {
|
||||
|
||||
$PYTHON -m pip install --upgrade $RELEASE
|
||||
cd tools
|
||||
$PYTHON fail_if_no_c.py
|
||||
|
||||
if [ "$PYTHON_IMPL" = "CPython" ]; then
|
||||
$PYTHON fail_if_no_c.py
|
||||
fi
|
||||
|
||||
$PYTHON -m pip uninstall -y pymongo
|
||||
cd ..
|
||||
|
||||
|
||||
@ -102,3 +102,15 @@ repos:
|
||||
# - test/versioned-api/crud-api-version-1-strict.json:514: nin ==> inn, min, bin, nine
|
||||
# - test/test_client.py:188: te ==> the, be, we, to
|
||||
args: ["-L", "fle,fo,infinit,isnt,nin,te,aks"]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: executable-shell
|
||||
name: executable-shell
|
||||
entry: chmod +x
|
||||
language: system
|
||||
types: [shell]
|
||||
exclude: |
|
||||
(?x)(
|
||||
.evergreen/retry-with-backoff.sh
|
||||
)
|
||||
|
||||
@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
|
||||
2) License Notice for _asyncio_lock.py
|
||||
-----------------------------------------
|
||||
|
||||
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
||||
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
||||
otherwise using this software ("Python") in source or binary form and
|
||||
its associated documentation.
|
||||
|
||||
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
||||
analyze, test, perform and/or display publicly, prepare derivative works,
|
||||
distribute, and otherwise use Python alone or in any derivative version,
|
||||
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
||||
i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved"
|
||||
are retained in Python alone or in any derivative version prepared by Licensee.
|
||||
|
||||
3. In the event Licensee prepares a derivative work that is based on
|
||||
or incorporates Python or any part thereof, and wants to make
|
||||
the derivative work available to others as provided herein, then
|
||||
Licensee hereby agrees to include in any such work a brief summary of
|
||||
the changes made to Python.
|
||||
|
||||
4. PSF is making Python available to Licensee on an "AS IS"
|
||||
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
||||
|
||||
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
||||
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
||||
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
||||
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||
|
||||
6. This License Agreement will automatically terminate upon a material
|
||||
breach of its terms and conditions.
|
||||
|
||||
7. Nothing in this License Agreement shall be deemed to create any
|
||||
relationship of agency, partnership, or joint venture between PSF and
|
||||
Licensee. This License Agreement does not grant permission to use PSF
|
||||
trademarks or trade name in a trademark sense to endorse or promote
|
||||
products or services of Licensee, or any third party.
|
||||
|
||||
8. By copying, installing or otherwise using Python, Licensee
|
||||
agrees to be bound by the terms and conditions of this License
|
||||
Agreement.
|
||||
|
||||
Permission to use, copy, modify, and/or distribute this software for any
|
||||
purpose with or without fee is hereby granted.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
||||
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
||||
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
||||
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
||||
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
||||
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
||||
PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
309
pymongo/_asyncio_lock.py
Normal file
309
pymongo/_asyncio_lock.py
Normal file
@ -0,0 +1,309 @@
|
||||
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
|
||||
|
||||
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
|
||||
to port 3.13 fixes to older versions of Python.
|
||||
Can be removed once we drop Python 3.12 support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import threading
|
||||
from asyncio import events, exceptions
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
class _LoopBoundMixin:
|
||||
_loop = None
|
||||
|
||||
def _get_loop(self) -> Any:
|
||||
loop = events._get_running_loop()
|
||||
|
||||
if self._loop is None:
|
||||
with _global_lock:
|
||||
if self._loop is None:
|
||||
self._loop = loop
|
||||
if loop is not self._loop:
|
||||
raise RuntimeError(f"{self!r} is bound to a different event loop")
|
||||
return loop
|
||||
|
||||
|
||||
class _ContextManagerMixin:
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire() # type: ignore[attr-defined]
|
||||
# We have no use for the "as ..." clause in the with
|
||||
# statement for locks.
|
||||
return
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class Lock(_ContextManagerMixin, _LoopBoundMixin):
|
||||
"""Primitive lock objects.
|
||||
|
||||
A primitive lock is a synchronization primitive that is not owned
|
||||
by a particular task when locked. A primitive lock is in one
|
||||
of two states, 'locked' or 'unlocked'.
|
||||
|
||||
It is created in the unlocked state. It has two basic methods,
|
||||
acquire() and release(). When the state is unlocked, acquire()
|
||||
changes the state to locked and returns immediately. When the
|
||||
state is locked, acquire() blocks until a call to release() in
|
||||
another task changes it to unlocked, then the acquire() call
|
||||
resets it to locked and returns. The release() method should only
|
||||
be called in the locked state; it changes the state to unlocked
|
||||
and returns immediately. If an attempt is made to release an
|
||||
unlocked lock, a RuntimeError will be raised.
|
||||
|
||||
When more than one task is blocked in acquire() waiting for
|
||||
the state to turn to unlocked, only one task proceeds when a
|
||||
release() call resets the state to unlocked; successive release()
|
||||
calls will unblock tasks in FIFO order.
|
||||
|
||||
Locks also support the asynchronous context management protocol.
|
||||
'async with lock' statement should be used.
|
||||
|
||||
Usage:
|
||||
|
||||
lock = Lock()
|
||||
...
|
||||
await lock.acquire()
|
||||
try:
|
||||
...
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
Context manager usage:
|
||||
|
||||
lock = Lock()
|
||||
...
|
||||
async with lock:
|
||||
...
|
||||
|
||||
Lock objects can be tested for locking state:
|
||||
|
||||
if not lock.locked():
|
||||
await lock.acquire()
|
||||
else:
|
||||
# lock is acquired
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._waiters: Optional[collections.deque] = None
|
||||
self._locked = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
res = super().__repr__()
|
||||
extra = "locked" if self._locked else "unlocked"
|
||||
if self._waiters:
|
||||
extra = f"{extra}, waiters:{len(self._waiters)}"
|
||||
return f"<{res[1:-1]} [{extra}]>"
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Return True if lock is acquired."""
|
||||
return self._locked
|
||||
|
||||
async def acquire(self) -> bool:
|
||||
"""Acquire a lock.
|
||||
|
||||
This method blocks until the lock is unlocked, then sets it to
|
||||
locked and returns True.
|
||||
"""
|
||||
# Implement fair scheduling, where thread always waits
|
||||
# its turn. Jumping the queue if all are cancelled is an optimization.
|
||||
if not self._locked and (
|
||||
self._waiters is None or all(w.cancelled() for w in self._waiters)
|
||||
):
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
if self._waiters is None:
|
||||
self._waiters = collections.deque()
|
||||
fut = self._get_loop().create_future()
|
||||
self._waiters.append(fut)
|
||||
|
||||
try:
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
self._waiters.remove(fut)
|
||||
except exceptions.CancelledError:
|
||||
# Currently the only exception designed be able to occur here.
|
||||
|
||||
# Ensure the lock invariant: If lock is not claimed (or about
|
||||
# to be claimed by us) and there is a Task in waiters,
|
||||
# ensure that the Task at the head will run.
|
||||
if not self._locked:
|
||||
self._wake_up_first()
|
||||
raise
|
||||
|
||||
# assert self._locked is False
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release a lock.
|
||||
|
||||
When the lock is locked, reset it to unlocked, and return.
|
||||
If any other tasks are blocked waiting for the lock to become
|
||||
unlocked, allow exactly one of them to proceed.
|
||||
|
||||
When invoked on an unlocked lock, a RuntimeError is raised.
|
||||
|
||||
There is no return value.
|
||||
"""
|
||||
if self._locked:
|
||||
self._locked = False
|
||||
self._wake_up_first()
|
||||
else:
|
||||
raise RuntimeError("Lock is not acquired.")
|
||||
|
||||
def _wake_up_first(self) -> None:
|
||||
"""Ensure that the first waiter will wake up."""
|
||||
if not self._waiters:
|
||||
return
|
||||
try:
|
||||
fut = next(iter(self._waiters))
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# .done() means that the waiter is already set to wake up.
|
||||
if not fut.done():
|
||||
fut.set_result(True)
|
||||
|
||||
|
||||
class Condition(_ContextManagerMixin, _LoopBoundMixin):
|
||||
"""Asynchronous equivalent to threading.Condition.
|
||||
|
||||
This class implements condition variable objects. A condition variable
|
||||
allows one or more tasks to wait until they are notified by another
|
||||
task.
|
||||
|
||||
A new Lock object is created and used as the underlying lock.
|
||||
"""
|
||||
|
||||
def __init__(self, lock: Optional[Lock] = None) -> None:
|
||||
if lock is None:
|
||||
lock = Lock()
|
||||
|
||||
self._lock = lock
|
||||
# Export the lock's locked(), acquire() and release() methods.
|
||||
self.locked = lock.locked
|
||||
self.acquire = lock.acquire
|
||||
self.release = lock.release
|
||||
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
res = super().__repr__()
|
||||
extra = "locked" if self.locked() else "unlocked"
|
||||
if self._waiters:
|
||||
extra = f"{extra}, waiters:{len(self._waiters)}"
|
||||
return f"<{res[1:-1]} [{extra}]>"
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Wait until notified.
|
||||
|
||||
If the calling task has not acquired the lock when this
|
||||
method is called, a RuntimeError is raised.
|
||||
|
||||
This method releases the underlying lock, and then blocks
|
||||
until it is awakened by a notify() or notify_all() call for
|
||||
the same condition variable in another task. Once
|
||||
awakened, it re-acquires the lock and returns True.
|
||||
|
||||
This method may return spuriously,
|
||||
which is why the caller should always
|
||||
re-check the state and be prepared to wait() again.
|
||||
"""
|
||||
if not self.locked():
|
||||
raise RuntimeError("cannot wait on un-acquired lock")
|
||||
|
||||
fut = self._get_loop().create_future()
|
||||
self.release()
|
||||
try:
|
||||
try:
|
||||
self._waiters.append(fut)
|
||||
try:
|
||||
await fut
|
||||
return True
|
||||
finally:
|
||||
self._waiters.remove(fut)
|
||||
|
||||
finally:
|
||||
# Must re-acquire lock even if wait is cancelled.
|
||||
# We only catch CancelledError here, since we don't want any
|
||||
# other (fatal) errors with the future to cause us to spin.
|
||||
err = None
|
||||
while True:
|
||||
try:
|
||||
await self.acquire()
|
||||
break
|
||||
except exceptions.CancelledError as e:
|
||||
err = e
|
||||
|
||||
if err is not None:
|
||||
try:
|
||||
raise err # Re-raise most recent exception instance.
|
||||
finally:
|
||||
err = None # Break reference cycles.
|
||||
except BaseException:
|
||||
# Any error raised out of here _may_ have occurred after this Task
|
||||
# believed to have been successfully notified.
|
||||
# Make sure to notify another Task instead. This may result
|
||||
# in a "spurious wakeup", which is allowed as part of the
|
||||
# Condition Variable protocol.
|
||||
self._notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Any) -> Coroutine:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
interpreted as a boolean value. The method will repeatedly
|
||||
wait() until it evaluates to true. The final predicate value is
|
||||
the return value.
|
||||
"""
|
||||
result = predicate()
|
||||
while not result:
|
||||
await self.wait()
|
||||
result = predicate()
|
||||
return result
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
"""By default, wake up one task waiting on this condition, if any.
|
||||
If the calling task has not acquired the lock when this method
|
||||
is called, a RuntimeError is raised.
|
||||
|
||||
This method wakes up n of the tasks waiting for the condition
|
||||
variable; if fewer than n are waiting, they are all awoken.
|
||||
|
||||
Note: an awakened task does not actually return from its
|
||||
wait() call until it can reacquire the lock. Since notify() does
|
||||
not release the lock, its caller should.
|
||||
"""
|
||||
if not self.locked():
|
||||
raise RuntimeError("cannot notify on un-acquired lock")
|
||||
self._notify(n)
|
||||
|
||||
def _notify(self, n: int) -> None:
|
||||
idx = 0
|
||||
for fut in self._waiters:
|
||||
if idx >= n:
|
||||
break
|
||||
|
||||
if not fut.done():
|
||||
idx += 1
|
||||
fut.set_result(False)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
"""Wake up all tasks waiting on this condition. This method acts
|
||||
like notify(), but wakes up all waiting tasks instead of one. If the
|
||||
calling task has not acquired the lock when this method is called,
|
||||
a RuntimeError is raised.
|
||||
"""
|
||||
self.notify(len(self._waiters))
|
||||
49
pymongo/_asyncio_task.py
Normal file
49
pymongo/_asyncio_task.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
|
||||
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
|
||||
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
|
||||
class _Task(asyncio.Task):
|
||||
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
|
||||
super().__init__(coro, name=name)
|
||||
self._cancel_requests = 0
|
||||
asyncio._register_task(self)
|
||||
|
||||
def cancel(self, msg: Optional[str] = None) -> bool:
|
||||
self._cancel_requests += 1
|
||||
return super().cancel(msg=msg)
|
||||
|
||||
def uncancel(self) -> int:
|
||||
if self._cancel_requests > 0:
|
||||
self._cancel_requests -= 1
|
||||
return self._cancel_requests
|
||||
|
||||
def cancelling(self) -> int:
|
||||
return self._cancel_requests
|
||||
|
||||
|
||||
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
|
||||
if sys.version_info >= (3, 11):
|
||||
return asyncio.create_task(coro, name=name)
|
||||
return _Task(coro, name=name)
|
||||
@ -476,7 +476,6 @@ class _AsyncClientBulk:
|
||||
if op_type == "delete":
|
||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||
full_result[f"{op_type}Results"][original_index] = res
|
||||
|
||||
except Exception as exc:
|
||||
# Attempt to close the cursor, then raise top-level error.
|
||||
if cmd_cursor.alive:
|
||||
|
||||
@ -45,7 +45,7 @@ from pymongo.common import (
|
||||
)
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
from pymongo.lock import _async_create_lock
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
@ -77,7 +77,7 @@ class _ConnectionManager:
|
||||
def __init__(self, conn: AsyncConnection, more_to_come: bool):
|
||||
self.conn: Optional[AsyncConnection] = conn
|
||||
self.more_to_come = more_to_come
|
||||
self._alock = _ALock(_create_lock())
|
||||
self._lock = _async_create_lock()
|
||||
|
||||
def update_exhaust(self, more_to_come: bool) -> None:
|
||||
self.more_to_come = more_to_come
|
||||
@ -1299,7 +1299,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
|
||||
>>> await cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
Or, to read at most n items from the cursor::
|
||||
|
||||
>>> await cursor.to_list(n)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import socket
|
||||
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
# BSON encoding/decoding errors are unrelated to encryption so
|
||||
# we should propagate them unchanged.
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptionError(exc) from exc
|
||||
|
||||
@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
conn.close()
|
||||
except (PyMongoError, MongoCryptError):
|
||||
raise # Propagate pymongo errors directly.
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as error:
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
await database.create_collection(name=name, **kwargs),
|
||||
encrypted_fields,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ access:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
@ -59,8 +60,8 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
||||
from pymongo.asynchronous import client_session, database, periodic_executor
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||
from pymongo.asynchronous import client_session, database
|
||||
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||
@ -82,7 +83,11 @@ from pymongo.errors import (
|
||||
WaitQueueTimeoutError,
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
|
||||
from pymongo.lock import (
|
||||
_HAS_REGISTER_AT_FORK,
|
||||
_async_create_lock,
|
||||
_release_locks,
|
||||
)
|
||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||
from pymongo.monitoring import ConnectionClosedReason
|
||||
@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||
|
||||
self._default_database_name = dbase
|
||||
self._lock = _ALock(_create_lock())
|
||||
self._lock = _async_create_lock()
|
||||
self._kill_cursors_queue: list = []
|
||||
|
||||
self._event_listeners = options.pool_options._event_listeners
|
||||
@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
await AsyncMongoClient._process_periodic_tasks(client)
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||
interval=common.KILL_CURSOR_FREQUENCY,
|
||||
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
||||
target=target,
|
||||
@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
address=address,
|
||||
)
|
||||
|
||||
async with operation.conn_mgr._alock:
|
||||
async with operation.conn_mgr._lock:
|
||||
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||
return await server.run_operation(
|
||||
@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
try:
|
||||
if conn_mgr:
|
||||
async with conn_mgr._alock:
|
||||
async with conn_mgr._lock:
|
||||
# Cursor is pinned to LB outside of a transaction.
|
||||
assert address is not None
|
||||
assert conn_mgr.conn is not None
|
||||
@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_ids in address_to_cursor_ids.items():
|
||||
try:
|
||||
await self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
try:
|
||||
await self._process_kill_cursors()
|
||||
await self._topology.update_pool()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -16,20 +16,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo import common
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo.asynchronous.periodic_executor import _shutdown_executors
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.lock import _async_create_lock
|
||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||
from pymongo.periodic_executor import _shutdown_executors
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
@ -76,7 +76,7 @@ class MonitorBase:
|
||||
await monitor._run() # type:ignore[attr-defined]
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||
interval=interval, min_interval=min_interval, target=target, name=name
|
||||
)
|
||||
|
||||
@ -112,9 +112,9 @@ class MonitorBase:
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
self._executor.join(timeout)
|
||||
await self._executor.join(timeout)
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
@ -139,7 +139,7 @@ class Monitor(MonitorBase):
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_server_monitor_thread",
|
||||
"pymongo_server_monitor_task",
|
||||
topology_settings.heartbeat_frequency,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
)
|
||||
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
finally:
|
||||
if self._executor._stopped:
|
||||
await self._rtt_monitor.close()
|
||||
|
||||
async def _check_server(self) -> ServerDescription:
|
||||
"""Call hello or read the next streaming response.
|
||||
@ -252,8 +255,10 @@ class Monitor(MonitorBase):
|
||||
except (OperationFailure, NotPrimaryError) as exc:
|
||||
# Update max cluster time even when hello fails.
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
await self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
except Exception as error:
|
||||
@ -280,7 +285,7 @@ class Monitor(MonitorBase):
|
||||
await self._reset_connection()
|
||||
if isinstance(error, _OperationCancelled):
|
||||
raise
|
||||
self._rtt_monitor.reset()
|
||||
await self._rtt_monitor.reset()
|
||||
# Server type defaults to Unknown.
|
||||
return ServerDescription(address, error=error)
|
||||
|
||||
@ -321,9 +326,9 @@ class Monitor(MonitorBase):
|
||||
self._conn_id = conn.id
|
||||
response, round_trip_time = await self._check_with_socket(conn)
|
||||
if not response.awaitable:
|
||||
self._rtt_monitor.add_sample(round_trip_time)
|
||||
await self._rtt_monitor.add_sample(round_trip_time)
|
||||
|
||||
avg_rtt, min_rtt = self._rtt_monitor.get()
|
||||
avg_rtt, min_rtt = await self._rtt_monitor.get()
|
||||
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase):
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_server_rtt_thread",
|
||||
"pymongo_server_rtt_task",
|
||||
topology_settings.heartbeat_frequency,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
)
|
||||
@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase):
|
||||
self._pool = pool
|
||||
self._moving_average = MovingAverage()
|
||||
self._moving_min = MovingMinimum()
|
||||
self._lock = _create_lock()
|
||||
self._lock = _async_create_lock()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase):
|
||||
# thread has the socket checked out, it will be closed when checked in.
|
||||
await self._pool.reset()
|
||||
|
||||
def add_sample(self, sample: float) -> None:
|
||||
async def add_sample(self, sample: float) -> None:
|
||||
"""Add a RTT sample."""
|
||||
with self._lock:
|
||||
async with self._lock:
|
||||
self._moving_average.add_sample(sample)
|
||||
self._moving_min.add_sample(sample)
|
||||
|
||||
def get(self) -> tuple[Optional[float], float]:
|
||||
async def get(self) -> tuple[Optional[float], float]:
|
||||
"""Get the calculated average, or None if no samples yet and the min."""
|
||||
with self._lock:
|
||||
async with self._lock:
|
||||
return self._moving_average.get(), self._moving_min.get()
|
||||
|
||||
def reset(self) -> None:
|
||||
async def reset(self) -> None:
|
||||
"""Reset the average RTT."""
|
||||
with self._lock:
|
||||
async with self._lock:
|
||||
self._moving_average.reset()
|
||||
self._moving_min.reset()
|
||||
|
||||
@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase):
|
||||
# heartbeat protocol (MongoDB 4.4+).
|
||||
# XXX: Skip check if the server is unknown?
|
||||
rtt = await self._ping()
|
||||
self.add_sample(rtt)
|
||||
await self.add_sample(rtt)
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
await self._pool.reset()
|
||||
|
||||
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
|
||||
shutdown()
|
||||
|
||||
|
||||
atexit.register(_shutdown_resources)
|
||||
if _IS_SYNC:
|
||||
atexit.register(_shutdown_resources)
|
||||
|
||||
@ -1,219 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Run a target function on a background thread."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class PeriodicExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
interval: float,
|
||||
min_interval: float,
|
||||
target: Any,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Run a target function periodically on a background thread.
|
||||
|
||||
If the target's return value is false, the executor stops.
|
||||
|
||||
:param interval: Seconds between calls to `target`.
|
||||
:param min_interval: Minimum seconds between calls if `wake` is
|
||||
called very often.
|
||||
:param target: A function.
|
||||
:param name: A name to give the underlying thread.
|
||||
"""
|
||||
# threading.Event and its internal condition variable are expensive
|
||||
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
|
||||
# The executor's design is constrained by several Python issues, see
|
||||
# "periodic_executor.rst" in this repository.
|
||||
self._event = False
|
||||
self._interval = interval
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
self._thread_will_exit = False
|
||||
self._lock = _ALock(_create_lock())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def _run_async(self) -> None:
|
||||
# The default asyncio loop implementation on Windows
|
||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
||||
# We explicitly use a different loop implementation here to prevent that issue
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.SelectorEventLoop()
|
||||
try:
|
||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
Not safe to call from multiple threads at once.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._thread_will_exit:
|
||||
# If the background thread has read self._stopped as True
|
||||
# there is a chance that it has not yet exited. The call to
|
||||
# join should not block indefinitely because there is no
|
||||
# other work done outside the while loop in self._run.
|
||||
try:
|
||||
assert self._thread is not None
|
||||
self._thread.join()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
self._thread_will_exit = False
|
||||
self._stopped = False
|
||||
started: Any = False
|
||||
try:
|
||||
started = self._thread and self._thread.is_alive()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
|
||||
if not started:
|
||||
if _IS_SYNC:
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
else:
|
||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
||||
thread.daemon = True
|
||||
self._thread = weakref.proxy(thread)
|
||||
_register_executor(self)
|
||||
# Mitigation to RuntimeError firing when thread starts on shutdown
|
||||
# https://github.com/python/cpython/issues/114570
|
||||
try:
|
||||
thread.start()
|
||||
except RuntimeError as e:
|
||||
if "interpreter shutdown" in str(e) or sys.is_finalizing():
|
||||
self._thread = None
|
||||
return
|
||||
raise
|
||||
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._thread is not None:
|
||||
try:
|
||||
self._thread.join(timeout)
|
||||
except (ReferenceError, RuntimeError):
|
||||
# Thread already terminated, or not yet started.
|
||||
pass
|
||||
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _should_stop(self) -> bool:
|
||||
async with self._lock:
|
||||
if self._stopped:
|
||||
self._thread_will_exit = True
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _run(self) -> None:
|
||||
while not await self._should_stop():
|
||||
try:
|
||||
if not await self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
except BaseException:
|
||||
async with self._lock:
|
||||
self._stopped = True
|
||||
self._thread_will_exit = True
|
||||
|
||||
raise
|
||||
|
||||
if self._skip_sleep:
|
||||
self._skip_sleep = False
|
||||
else:
|
||||
deadline = time.monotonic() + self._interval
|
||||
while not self._stopped and time.monotonic() < deadline:
|
||||
await asyncio.sleep(self._min_interval)
|
||||
if self._event:
|
||||
break # Early wake.
|
||||
|
||||
self._event = False
|
||||
|
||||
|
||||
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
|
||||
# an executor is kept alive by a strong reference from its thread and perhaps
|
||||
# from other objects. When the thread dies and all other referrers are freed,
|
||||
# the executor is freed and removed from _EXECUTORS. If any threads are
|
||||
# running when the interpreter begins to shut down, we try to halt and join
|
||||
# them to avoid spurious errors.
|
||||
_EXECUTORS = set()
|
||||
|
||||
|
||||
def _register_executor(executor: PeriodicExecutor) -> None:
|
||||
ref = weakref.ref(executor, _on_executor_deleted)
|
||||
_EXECUTORS.add(ref)
|
||||
|
||||
|
||||
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
|
||||
_EXECUTORS.remove(ref)
|
||||
|
||||
|
||||
def _shutdown_executors() -> None:
|
||||
if _EXECUTORS is None:
|
||||
return
|
||||
|
||||
# Copy the set. Stopping threads has the side effect of removing executors.
|
||||
executors = list(_EXECUTORS)
|
||||
|
||||
# First signal all executors to close...
|
||||
for ref in executors:
|
||||
executor = ref()
|
||||
if executor:
|
||||
executor.close()
|
||||
|
||||
# ...then try to join them.
|
||||
for ref in executors:
|
||||
executor = ref()
|
||||
if executor:
|
||||
executor.join(1)
|
||||
|
||||
executor = None
|
||||
@ -23,7 +23,6 @@ import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import (
|
||||
@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
||||
from pymongo.lock import (
|
||||
_async_cond_wait,
|
||||
_async_create_condition,
|
||||
_async_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
@ -208,11 +211,6 @@ def _raise_connection_failure(
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
return await condition.wait(timeout)
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
@ -706,6 +704,8 @@ class AsyncConnection:
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@ -992,8 +992,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
_lock = _create_lock()
|
||||
self.lock = _ALock(_lock)
|
||||
self.lock = _async_create_lock()
|
||||
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1019,7 +1019,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = _ACondition(threading.Condition(_lock))
|
||||
self.size_cond = _async_create_condition(self.lock)
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1027,7 +1027,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
|
||||
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
@ -1466,7 +1466,8 @@ class Pool:
|
||||
async with self.size_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
||||
while not (self.requests < self.max_pool_size):
|
||||
if not await _cond_wait(self.size_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not await _async_cond_wait(self.size_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.requests < self.max_pool_size:
|
||||
@ -1489,7 +1490,8 @@ class Pool:
|
||||
async with self._max_connecting_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
||||
while not (self.conns or self._pending < self._max_connecting):
|
||||
if not await _cond_wait(self._max_connecting_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not await _async_cond_wait(self._max_connecting_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.conns or self._pending < self._max_connecting:
|
||||
|
||||
@ -27,8 +27,7 @@ import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, common, helpers_shared
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.asynchronous.monitor import SrvMonitor
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
@ -44,7 +43,11 @@ from pymongo.errors import (
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
||||
from pymongo.lock import (
|
||||
_async_cond_wait,
|
||||
_async_create_condition,
|
||||
_async_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_SDAM_LOGGER,
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
@ -170,9 +173,10 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
_lock = _create_lock()
|
||||
self._lock = _ALock(_lock)
|
||||
self._condition = _ACondition(self._settings.condition_class(_lock))
|
||||
self._lock = _async_create_lock()
|
||||
self._condition = _async_create_condition(
|
||||
self._lock, self._settings.condition_class if _IS_SYNC else None
|
||||
)
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
@ -185,7 +189,7 @@ class Topology:
|
||||
async def target() -> bool:
|
||||
return process_events_queue(weak)
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||
interval=common.EVENTS_QUEUE_FREQUENCY,
|
||||
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
||||
target=target,
|
||||
@ -354,7 +358,7 @@ class Topology:
|
||||
# change, or for a timeout. We won't miss any changes that
|
||||
# came after our most recent apply_selector call, since we've
|
||||
# held the lock until now.
|
||||
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
|
||||
await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
|
||||
self._description.check_compatible()
|
||||
now = time.monotonic()
|
||||
server_descriptions = self._description.apply_selector(
|
||||
@ -654,7 +658,7 @@ class Topology:
|
||||
"""Wake all monitors, wait for at least one to check its server."""
|
||||
async with self._lock:
|
||||
self._request_check_all()
|
||||
await self._condition.wait(wait_time)
|
||||
await _async_cond_wait(self._condition, wait_time)
|
||||
|
||||
def data_bearing_servers(self) -> list[ServerDescription]:
|
||||
"""Return a list of all data-bearing servers.
|
||||
@ -742,7 +746,7 @@ class Topology:
|
||||
if self._publish_server or self._publish_tp:
|
||||
# Make sure the events executor thread is fully closed before publishing the remaining events
|
||||
self.__events_executor.close()
|
||||
self.__events_executor.join(1)
|
||||
await self.__events_executor.join(1)
|
||||
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
|
||||
|
||||
@property
|
||||
|
||||
245
pymongo/lock.py
245
pymongo/lock.py
@ -11,15 +11,20 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Internal helpers for lock and condition coordination primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from asyncio import wait_for
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
import pymongo._asyncio_lock
|
||||
|
||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||
|
||||
@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202)
|
||||
# in older versions of Python
|
||||
if sys.version_info >= (3, 13):
|
||||
Lock = asyncio.Lock
|
||||
Condition = asyncio.Condition
|
||||
else:
|
||||
Lock = pymongo._asyncio_lock.Lock
|
||||
Condition = pymongo._asyncio_lock.Condition
|
||||
|
||||
|
||||
def _create_lock() -> threading.Lock:
|
||||
"""Represents a lock that is tracked upon instantiation using a WeakSet and
|
||||
@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
def _async_create_lock() -> Lock:
|
||||
"""Represents an asyncio.Lock."""
|
||||
return Lock()
|
||||
|
||||
|
||||
def _create_condition(
|
||||
lock: threading.Lock, condition_class: Optional[Any] = None
|
||||
) -> threading.Condition:
|
||||
"""Represents a threading.Condition."""
|
||||
if condition_class:
|
||||
return condition_class(lock)
|
||||
return threading.Condition(lock)
|
||||
|
||||
|
||||
def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition:
|
||||
"""Represents an asyncio.Condition."""
|
||||
if condition_class:
|
||||
return condition_class(lock)
|
||||
return Condition(lock)
|
||||
|
||||
|
||||
def _release_locks() -> None:
|
||||
# Completed the fork, reset all the locks in the child.
|
||||
for lock in _forkable_locks:
|
||||
@ -46,202 +81,12 @@ def _release_locks() -> None:
|
||||
lock.release()
|
||||
|
||||
|
||||
# Needed only for synchro.py compat.
|
||||
def _Lock(lock: threading.Lock) -> threading.Lock:
|
||||
return lock
|
||||
async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool:
|
||||
try:
|
||||
return await wait_for(condition.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
class _ALock:
|
||||
__slots__ = ("_lock",)
|
||||
|
||||
def __init__(self, lock: threading.Lock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
return self._lock.acquire(blocking=blocking, timeout=timeout)
|
||||
|
||||
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
if timeout > 0:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
acquired = self._lock.acquire(blocking=False)
|
||||
if acquired:
|
||||
return True
|
||||
if timeout > 0 and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
if not blocking:
|
||||
return False
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
async def __aenter__(self) -> _ALock:
|
||||
await self.a_acquire()
|
||||
return self
|
||||
|
||||
def __enter__(self) -> _ALock:
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
def _safe_set_result(fut: asyncio.Future) -> None:
|
||||
# Ensure the future hasn't been cancelled before calling set_result.
|
||||
if not fut.done():
|
||||
fut.set_result(False)
|
||||
|
||||
|
||||
class _ACondition:
|
||||
__slots__ = ("_condition", "_waiters")
|
||||
|
||||
def __init__(self, condition: threading.Condition) -> None:
|
||||
self._condition = condition
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
|
||||
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
if timeout > 0:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
acquired = self._condition.acquire(blocking=False)
|
||||
if acquired:
|
||||
return True
|
||||
if timeout > 0 and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
if not blocking:
|
||||
return False
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
"""Wait until notified.
|
||||
|
||||
If the calling task has not acquired the lock when this
|
||||
method is called, a RuntimeError is raised.
|
||||
|
||||
This method releases the underlying lock, and then blocks
|
||||
until it is awakened by a notify() or notify_all() call for
|
||||
the same condition variable in another task. Once
|
||||
awakened, it re-acquires the lock and returns True.
|
||||
|
||||
This method may return spuriously,
|
||||
which is why the caller should always
|
||||
re-check the state and be prepared to wait() again.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
fut = loop.create_future()
|
||||
self._waiters.append((loop, fut))
|
||||
self.release()
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False # Return false on timeout for sync pool compat.
|
||||
finally:
|
||||
# Must re-acquire lock even if wait is cancelled.
|
||||
# We only catch CancelledError here, since we don't want any
|
||||
# other (fatal) errors with the future to cause us to spin.
|
||||
err = None
|
||||
while True:
|
||||
try:
|
||||
await self.acquire()
|
||||
break
|
||||
except asyncio.exceptions.CancelledError as e:
|
||||
err = e
|
||||
|
||||
self._waiters.remove((loop, fut))
|
||||
if err is not None:
|
||||
try:
|
||||
raise err # Re-raise most recent exception instance.
|
||||
finally:
|
||||
err = None # Break reference cycles.
|
||||
except BaseException:
|
||||
# Any error raised out of here _may_ have occurred after this Task
|
||||
# believed to have been successfully notified.
|
||||
# Make sure to notify another Task instead. This may result
|
||||
# in a "spurious wakeup", which is allowed as part of the
|
||||
# Condition Variable protocol.
|
||||
self.notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
interpreted as a boolean value. The method will repeatedly
|
||||
wait() until it evaluates to true. The final predicate value is
|
||||
the return value.
|
||||
"""
|
||||
result = predicate()
|
||||
while not result:
|
||||
await self.wait()
|
||||
result = predicate()
|
||||
return result
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
"""By default, wake up one coroutine waiting on this condition, if any.
|
||||
If the calling coroutine has not acquired the lock when this method
|
||||
is called, a RuntimeError is raised.
|
||||
|
||||
This method wakes up at most n of the coroutines waiting for the
|
||||
condition variable; it is a no-op if no coroutines are waiting.
|
||||
|
||||
Note: an awakened coroutine does not actually return from its
|
||||
wait() call until it can reacquire the lock. Since notify() does
|
||||
not release the lock, its caller should.
|
||||
"""
|
||||
idx = 0
|
||||
to_remove = []
|
||||
for loop, fut in self._waiters:
|
||||
if idx >= n:
|
||||
break
|
||||
|
||||
if fut.done():
|
||||
continue
|
||||
|
||||
try:
|
||||
loop.call_soon_threadsafe(_safe_set_result, fut)
|
||||
except RuntimeError:
|
||||
# Loop was closed, ignore.
|
||||
to_remove.append((loop, fut))
|
||||
continue
|
||||
|
||||
idx += 1
|
||||
|
||||
for waiter in to_remove:
|
||||
self._waiters.remove(waiter)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
"""Wake up all threads waiting on this condition. This method acts
|
||||
like notify(), but wakes up all waiting threads instead of one. If the
|
||||
calling thread has not acquired the lock when this method is called,
|
||||
a RuntimeError is raised.
|
||||
"""
|
||||
self.notify(len(self._waiters))
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Only needed for tests in test_locks."""
|
||||
return self._condition._lock.locked() # type: ignore[attr-defined]
|
||||
|
||||
def release(self) -> None:
|
||||
self._condition.release()
|
||||
|
||||
async def __aenter__(self) -> _ACondition:
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
def __enter__(self) -> _ACondition:
|
||||
self._condition.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool:
|
||||
return condition.wait(timeout)
|
||||
|
||||
@ -28,7 +28,8 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import _csot, ssl_support
|
||||
from pymongo import ssl_support
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.errors import _OperationCancelled
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
@ -259,19 +260,20 @@ async def async_receive_data(
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_event_loop()
|
||||
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
|
||||
cancellation_task = create_task(_poll_cancellation(conn))
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||
else:
|
||||
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||
tasks = [read_task, cancellation_task]
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
await asyncio.wait(pending)
|
||||
if pending:
|
||||
await asyncio.wait(pending)
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
if read_task in done:
|
||||
@ -314,62 +316,47 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
|
||||
return mv
|
||||
|
||||
|
||||
# Sync version:
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn
|
||||
timed_out = False
|
||||
# Check if the connection's socket has been manually closed
|
||||
if sock.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
if hasattr(sock, "pending") and sock.pending() > 0:
|
||||
readable = True
|
||||
else:
|
||||
# Wait up to 500ms for the socket to become readable and then
|
||||
# check for cancellation.
|
||||
if deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
if remaining <= 0:
|
||||
timed_out = True
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
if readable:
|
||||
return
|
||||
if timed_out:
|
||||
raise socket.timeout("timed out")
|
||||
|
||||
|
||||
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
try:
|
||||
wait_for_read(conn, deadline)
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
except OSError as exc:
|
||||
if _errno_from_exception(exc) == errno.EINTR:
|
||||
# To support cancelling a network read, we shorten the socket timeout and
|
||||
# check for the cancellation signal after each timeout. Alternatively we
|
||||
# could close the socket but that does not reliably cancel recv() calls
|
||||
# on all OSes.
|
||||
orig_timeout = conn.conn.gettimeout()
|
||||
try:
|
||||
while bytes_read < length:
|
||||
if deadline is not None:
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
|
||||
else:
|
||||
short_timeout = _POLL_TIMEOUT
|
||||
conn.set_conn_timeout(short_timeout)
|
||||
try:
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled") from None
|
||||
# We reached the true deadline.
|
||||
raise socket.timeout("timed out") from None
|
||||
except socket.timeout:
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled") from None
|
||||
continue
|
||||
raise
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
except OSError as exc:
|
||||
if conn.cancel_context.cancelled:
|
||||
raise _OperationCancelled("operation cancelled") from None
|
||||
if _errno_from_exception(exc) == errno.EINTR:
|
||||
continue
|
||||
raise
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
|
||||
bytes_read += chunk_length
|
||||
bytes_read += chunk_length
|
||||
finally:
|
||||
conn.set_conn_timeout(orig_timeout)
|
||||
|
||||
return mv
|
||||
|
||||
@ -23,9 +23,102 @@ import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.lock import _create_lock
|
||||
|
||||
_IS_SYNC = True
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AsyncPeriodicExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
interval: float,
|
||||
min_interval: float,
|
||||
target: Any,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Run a target function periodically on a background task.
|
||||
|
||||
If the target's return value is false, the executor stops.
|
||||
|
||||
:param interval: Seconds between calls to `target`.
|
||||
:param min_interval: Minimum seconds between calls if `wake` is
|
||||
called very often.
|
||||
:param target: A function.
|
||||
:param name: A name to give the underlying task.
|
||||
"""
|
||||
self._event = False
|
||||
self._interval = interval
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect."""
|
||||
self._stopped = False
|
||||
|
||||
if self._task is None or (
|
||||
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
|
||||
):
|
||||
self._task = create_task(self._run(), name=self._name)
|
||||
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type]
|
||||
except asyncio.TimeoutError:
|
||||
# Task timed out
|
||||
pass
|
||||
except asyncio.exceptions.CancelledError:
|
||||
# Task was already finished, or not yet started.
|
||||
raise
|
||||
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _run(self) -> None:
|
||||
while not self._stopped:
|
||||
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
|
||||
raise asyncio.CancelledError
|
||||
try:
|
||||
if not await self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
except BaseException:
|
||||
self._stopped = True
|
||||
raise
|
||||
|
||||
if self._skip_sleep:
|
||||
self._skip_sleep = False
|
||||
else:
|
||||
deadline = time.monotonic() + self._interval
|
||||
while not self._stopped and time.monotonic() < deadline:
|
||||
await asyncio.sleep(self._min_interval)
|
||||
if self._event:
|
||||
break # Early wake.
|
||||
|
||||
self._event = False
|
||||
|
||||
|
||||
class PeriodicExecutor:
|
||||
@ -64,19 +157,6 @@ class PeriodicExecutor:
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def _run_async(self) -> None:
|
||||
# The default asyncio loop implementation on Windows
|
||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
||||
# We explicitly use a different loop implementation here to prevent that issue
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.SelectorEventLoop()
|
||||
try:
|
||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
@ -104,10 +184,7 @@ class PeriodicExecutor:
|
||||
pass
|
||||
|
||||
if not started:
|
||||
if _IS_SYNC:
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
else:
|
||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
thread.daemon = True
|
||||
self._thread = weakref.proxy(thread)
|
||||
_register_executor(self)
|
||||
@ -474,7 +474,6 @@ class _ClientBulk:
|
||||
if op_type == "delete":
|
||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||
full_result[f"{op_type}Results"][original_index] = res
|
||||
|
||||
except Exception as exc:
|
||||
# Attempt to close the cursor, then raise top-level error.
|
||||
if cmd_cursor.alive:
|
||||
|
||||
@ -77,7 +77,7 @@ class _ConnectionManager:
|
||||
def __init__(self, conn: Connection, more_to_come: bool):
|
||||
self.conn: Optional[Connection] = conn
|
||||
self.more_to_come = more_to_come
|
||||
self._alock = _create_lock()
|
||||
self._lock = _create_lock()
|
||||
|
||||
def update_exhaust(self, more_to_come: bool) -> None:
|
||||
self.more_to_come = more_to_come
|
||||
@ -1297,7 +1297,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
|
||||
>>> cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
Or, to read at most n items from the cursor::
|
||||
|
||||
>>> cursor.to_list(n)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import socket
|
||||
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
# BSON encoding/decoding errors are unrelated to encryption so
|
||||
# we should propagate them unchanged.
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptionError(exc) from exc
|
||||
|
||||
@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
conn.close()
|
||||
except (PyMongoError, MongoCryptError):
|
||||
raise # Propagate pymongo errors directly.
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as error:
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
database.create_collection(name=name, **kwargs),
|
||||
encrypted_fields,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ access:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
@ -58,7 +59,7 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
@ -74,7 +75,11 @@ from pymongo.errors import (
|
||||
WaitQueueTimeoutError,
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks
|
||||
from pymongo.lock import (
|
||||
_HAS_REGISTER_AT_FORK,
|
||||
_create_lock,
|
||||
_release_locks,
|
||||
)
|
||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||
from pymongo.monitoring import ConnectionClosedReason
|
||||
@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import client_session, database, periodic_executor
|
||||
from pymongo.synchronous import client_session, database
|
||||
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||
@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
address=address,
|
||||
)
|
||||
|
||||
with operation.conn_mgr._alock:
|
||||
with operation.conn_mgr._lock:
|
||||
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||
return server.run_operation(
|
||||
@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
try:
|
||||
if conn_mgr:
|
||||
with conn_mgr._alock:
|
||||
with conn_mgr._lock:
|
||||
# Cursor is pinned to LB outside of a transaction.
|
||||
assert address is not None
|
||||
assert conn_mgr.conn is not None
|
||||
@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_ids in address_to_cursor_ids.items():
|
||||
try:
|
||||
self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
try:
|
||||
self._process_kill_cursors()
|
||||
self._topology.update_pool()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -16,24 +16,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo import common
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||
from pymongo.periodic_executor import _shutdown_executors
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
from pymongo.synchronous import periodic_executor
|
||||
from pymongo.synchronous.periodic_executor import _shutdown_executors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
||||
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
self.close()
|
||||
finally:
|
||||
if self._executor._stopped:
|
||||
self._rtt_monitor.close()
|
||||
|
||||
def _check_server(self) -> ServerDescription:
|
||||
"""Call hello or read the next streaming response.
|
||||
@ -254,6 +257,8 @@ class Monitor(MonitorBase):
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
except Exception as error:
|
||||
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
self.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self._pool.reset()
|
||||
|
||||
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
|
||||
shutdown()
|
||||
|
||||
|
||||
atexit.register(_shutdown_resources)
|
||||
if _IS_SYNC:
|
||||
atexit.register(_shutdown_resources)
|
||||
|
||||
@ -23,7 +23,6 @@ import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import (
|
||||
@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.lock import (
|
||||
_cond_wait,
|
||||
_create_condition,
|
||||
_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
@ -208,11 +211,6 @@ def _raise_connection_failure(
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
return condition.wait(timeout)
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
@ -704,6 +702,8 @@ class Connection:
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@ -988,8 +988,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
_lock = _create_lock()
|
||||
self.lock = _Lock(_lock)
|
||||
self.lock = _create_lock()
|
||||
self._max_connecting_cond = _create_condition(self.lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1015,7 +1015,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = threading.Condition(_lock)
|
||||
self.size_cond = _create_condition(self.lock)
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1023,7 +1023,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = threading.Condition(_lock)
|
||||
self._max_connecting_cond = _create_condition(self.lock)
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
@ -1460,7 +1460,8 @@ class Pool:
|
||||
with self.size_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
||||
while not (self.requests < self.max_pool_size):
|
||||
if not _cond_wait(self.size_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not _cond_wait(self.size_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.requests < self.max_pool_size:
|
||||
@ -1483,7 +1484,8 @@ class Pool:
|
||||
with self._max_connecting_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
||||
while not (self.conns or self._pending < self._max_connecting):
|
||||
if not _cond_wait(self._max_connecting_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not _cond_wait(self._max_connecting_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.conns or self._pending < self._max_connecting:
|
||||
|
||||
@ -27,7 +27,7 @@ import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, common, helpers_shared
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
@ -39,7 +39,11 @@ from pymongo.errors import (
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.lock import (
|
||||
_cond_wait,
|
||||
_create_condition,
|
||||
_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_SDAM_LOGGER,
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
@ -56,7 +60,6 @@ from pymongo.server_selectors import (
|
||||
secondary_server_selector,
|
||||
writable_server_selector,
|
||||
)
|
||||
from pymongo.synchronous import periodic_executor
|
||||
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.synchronous.monitor import SrvMonitor
|
||||
from pymongo.synchronous.pool import Pool
|
||||
@ -170,9 +173,10 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
_lock = _create_lock()
|
||||
self._lock = _Lock(_lock)
|
||||
self._condition = self._settings.condition_class(_lock)
|
||||
self._lock = _create_lock()
|
||||
self._condition = _create_condition(
|
||||
self._lock, self._settings.condition_class if _IS_SYNC else None
|
||||
)
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
@ -354,7 +358,7 @@ class Topology:
|
||||
# change, or for a timeout. We won't miss any changes that
|
||||
# came after our most recent apply_selector call, since we've
|
||||
# held the lock until now.
|
||||
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
|
||||
_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
|
||||
self._description.check_compatible()
|
||||
now = time.monotonic()
|
||||
server_descriptions = self._description.apply_selector(
|
||||
@ -652,7 +656,7 @@ class Topology:
|
||||
"""Wake all monitors, wait for at least one to check its server."""
|
||||
with self._lock:
|
||||
self._request_check_all()
|
||||
self._condition.wait(wait_time)
|
||||
_cond_wait(self._condition, wait_time)
|
||||
|
||||
def data_bearing_servers(self) -> list[ServerDescription]:
|
||||
"""Return a list of all data-bearing servers.
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
@ -25,6 +26,7 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
@ -191,6 +193,8 @@ class ClientContext:
|
||||
client.close()
|
||||
|
||||
def _init_client(self):
|
||||
self.mongoses = []
|
||||
self.connection_attempts = []
|
||||
self.client = self._connect(host, port)
|
||||
if self.client is not None:
|
||||
# Return early when connected to dataLake as mongohoused does not
|
||||
@ -860,6 +864,16 @@ class ClientContext:
|
||||
client_context = ClientContext()
|
||||
|
||||
|
||||
def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif client_context.client is not None:
|
||||
client_context.client.close()
|
||||
client_context.client = None
|
||||
client_context._init_client()
|
||||
|
||||
|
||||
class PyMongoTestCase(unittest.TestCase):
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase):
|
||||
class UnitTest(PyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
def setUp(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
def tearDown(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase):
|
||||
db: Database
|
||||
credentials: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
||||
def setUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
reset_client_context()
|
||||
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
||||
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
raise SkipTest("this test does not support serverless")
|
||||
cls.client = client_context.client
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.client = client_context.client
|
||||
self.db = self.client.pymongo_test
|
||||
if client_context.auth_enabled:
|
||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
||||
self.credentials = {"username": db_user, "password": db_pwd}
|
||||
else:
|
||||
cls.credentials = {}
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
pass
|
||||
self.credentials = {}
|
||||
|
||||
def cleanup_colls(self, *collections):
|
||||
"""Cleanup collections faster than drop_collection."""
|
||||
@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest):
|
||||
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||
# with loadBalanced=True.
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@client_context.require_no_load_balancer
|
||||
def _setup_class(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||
|
||||
self.client_knobs.enable()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
self.client_knobs.disable()
|
||||
super().tearDown()
|
||||
|
||||
@ -1253,7 +1211,6 @@ def teardown():
|
||||
c.drop_database("pymongo_test_mike")
|
||||
c.drop_database("pymongo_test_bernie")
|
||||
c.close()
|
||||
|
||||
print_running_clients()
|
||||
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
@ -25,6 +26,7 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
@ -191,6 +193,8 @@ class AsyncClientContext:
|
||||
await client.close()
|
||||
|
||||
async def _init_client(self):
|
||||
self.mongoses = []
|
||||
self.connection_attempts = []
|
||||
self.client = await self._connect(host, port)
|
||||
if self.client is not None:
|
||||
# Return early when connected to dataLake as mongohoused does not
|
||||
@ -862,6 +866,16 @@ class AsyncClientContext:
|
||||
async_client_context = AsyncClientContext()
|
||||
|
||||
|
||||
async def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif async_client_context.client is not None:
|
||||
await async_client_context.client.close()
|
||||
async_client_context.client = None
|
||||
await async_client_context._init_client()
|
||||
|
||||
|
||||
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
async def asyncSetUp(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
async def asyncTearDown(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
db: AsyncDatabase
|
||||
credentials: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
||||
async def asyncSetUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
await reset_client_context()
|
||||
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
||||
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
raise SkipTest("this test does not support serverless")
|
||||
cls.client = async_client_context.client
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.client = async_client_context.client
|
||||
self.db = self.client.pymongo_test
|
||||
if async_client_context.auth_enabled:
|
||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
||||
self.credentials = {"username": db_user, "password": db_pwd}
|
||||
else:
|
||||
cls.credentials = {}
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
pass
|
||||
self.credentials = {}
|
||||
|
||||
async def cleanup_colls(self, *collections):
|
||||
"""Cleanup collections faster than drop_collection."""
|
||||
@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest):
|
||||
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||
# with loadBalanced=True.
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_no_load_balancer
|
||||
async def _setup_class(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
|
||||
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||
|
||||
self.client_knobs.enable()
|
||||
|
||||
def tearDown(self):
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.client_knobs.disable()
|
||||
super().tearDown()
|
||||
await super().asyncTearDown()
|
||||
|
||||
|
||||
async def async_setup():
|
||||
@ -1271,7 +1229,6 @@ async def async_teardown():
|
||||
await c.drop_database("pymongo_test_mike")
|
||||
await c.drop_database("pymongo_test_bernie")
|
||||
await c.close()
|
||||
|
||||
print_running_clients()
|
||||
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ def event_loop_policy():
|
||||
return asyncio.get_event_loop_policy()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
@pytest_asyncio.fixture(scope="package", autouse=True)
|
||||
async def test_setup_and_teardown():
|
||||
await async_setup()
|
||||
yield
|
||||
|
||||
@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
|
||||
coll: AsyncCollection
|
||||
coll_w0: AsyncCollection
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.coll = cls.db.test
|
||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
async def asyncSetUp(self):
|
||||
super().setUp()
|
||||
await super().asyncSetUp()
|
||||
self.coll = self.db.test
|
||||
await self.coll.drop()
|
||||
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
def assertEqualResponse(self, expected, actual):
|
||||
"""Compare response from bulk.execute() to expected response."""
|
||||
@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
|
||||
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
|
||||
@classmethod
|
||||
@async_client_context.require_auth
|
||||
@async_client_context.require_no_api_version
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
super().setUp()
|
||||
await super().asyncSetUp()
|
||||
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
||||
await self.db.command(
|
||||
"createRole",
|
||||
@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
||||
w: Optional[int]
|
||||
secondary: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.w = async_client_context.w
|
||||
cls.secondary = None
|
||||
if cls.w is not None and cls.w > 1:
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.w = async_client_context.w
|
||||
self.secondary = None
|
||||
if self.w is not None and self.w > 1:
|
||||
for member in (await async_client_context.hello)["hosts"]:
|
||||
if member != (await async_client_context.hello)["primary"]:
|
||||
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
|
||||
self.secondary = await self.async_single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def async_tearDownClass(cls):
|
||||
if cls.secondary:
|
||||
await cls.secondary.close()
|
||||
async def asyncTearDown(self):
|
||||
if self.secondary:
|
||||
await self.secondary.close()
|
||||
|
||||
async def cause_wtimeout(self, requests, ordered):
|
||||
if not async_client_context.test_commands_enabled:
|
||||
|
||||
@ -836,18 +836,16 @@ class ProseSpecTestsMixin:
|
||||
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||
dbs: list
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_version_min(4, 0, 0, -1)
|
||||
@async_client_context.require_change_streams
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.dbs = [cls.db, cls.client.pymongo_test_2]
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.dbs = [self.db, self.client.pymongo_test_2]
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
for db in cls.dbs:
|
||||
await cls.client.drop_database(db)
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self):
|
||||
for db in self.dbs:
|
||||
await self.client.drop_database(db)
|
||||
await super().asyncTearDown()
|
||||
|
||||
async def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return await client.watch(*args, **kwargs)
|
||||
@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||
|
||||
|
||||
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||
@classmethod
|
||||
@async_client_context.require_version_min(4, 0, 0, -1)
|
||||
@async_client_context.require_change_streams
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return await client[self.db.name].watch(*args, **kwargs)
|
||||
@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi
|
||||
class TestAsyncCollectionAsyncChangeStream(
|
||||
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
|
||||
):
|
||||
@classmethod
|
||||
@async_client_context.require_change_streams
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
# Use a new collection for each test.
|
||||
await self.watched_collection().drop()
|
||||
await self.watched_collection().insert_one({})
|
||||
@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
listener: AllowListEventListener
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def asyncSetUp(self):
|
||||
super().asyncSetUp()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = AllowListEventListener("aggregate", "getMore")
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
self.listener.reset()
|
||||
|
||||
async def asyncSetUpCluster(self, scenario_dict):
|
||||
|
||||
@ -73,7 +73,6 @@ from test.utils import (
|
||||
is_greenthread_patched,
|
||||
lazy_client_trial,
|
||||
one,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
import bson
|
||||
@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
connect=False, serverSelectionTimeoutMS=100
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest):
|
||||
# When the reaper runs at the same time as the get_socket, two
|
||||
# connections could be created and checked into the pool.
|
||||
self.assertGreaterEqual(len(server._pool.conns), 1)
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
||||
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
||||
|
||||
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest):
|
||||
# When the reaper runs at the same time as the get_socket,
|
||||
# maxPoolSize=1 should prevent two connections from being created.
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
||||
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
||||
|
||||
async def test_max_idle_time_reaper_removes_stale(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
async with server._pool.checkout() as conn_two:
|
||||
pass
|
||||
self.assertIs(conn_one, conn_two)
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 0,
|
||||
"stale socket reaped and new one NOT added to the pool",
|
||||
)
|
||||
@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
server = await (await client._get_topology()).select_server(
|
||||
readable_server_selector, _Op.TEST
|
||||
)
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"pool initialized with 10 connections",
|
||||
)
|
||||
@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# Assert that if a socket is closed, a new one takes its place
|
||||
async with server._pool.checkout() as conn:
|
||||
conn.close_conn(None)
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"a closed socket gets replaced from the pool",
|
||||
)
|
||||
@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest):
|
||||
async with eval(the_repr) as client_two:
|
||||
self.assertEqual(client_two, client)
|
||||
|
||||
def test_getters(self):
|
||||
wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes")
|
||||
async def test_getters(self):
|
||||
await async_wait_until(
|
||||
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
|
||||
)
|
||||
|
||||
async def test_list_databases(self):
|
||||
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
|
||||
@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertFalse(client._topology._opened)
|
||||
|
||||
# Ensure kill cursors thread has not been started.
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertFalse(kc_task and not kc_task.done())
|
||||
# Using the client should open topology and start the thread.
|
||||
await client.admin.command("ping")
|
||||
self.assertTrue(client._topology._opened)
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertTrue(kc_task and not kc_task.done())
|
||||
|
||||
async def test_close_does_not_open_servers(self):
|
||||
client = await self.async_rs_client(connect=False)
|
||||
@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_server_selection_timeout(self):
|
||||
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
|
||||
|
||||
@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertRaises(
|
||||
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
|
||||
)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient(
|
||||
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
|
||||
)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
||||
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
# Test invalid timeout in URI ignored and set to default.
|
||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await async_client_context.port,
|
||||
)
|
||||
await self.async_single_client(uri, event_listeners=[listener])
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
|
||||
)
|
||||
|
||||
@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest):
|
||||
pool = await async_get_pool(client)
|
||||
original_connect = pool.connect
|
||||
|
||||
def stall_connect(*args, **kwargs):
|
||||
time.sleep(2)
|
||||
return original_connect(*args, **kwargs)
|
||||
async def stall_connect(*args, **kwargs):
|
||||
await asyncio.sleep(2)
|
||||
return await original_connect(*args, **kwargs)
|
||||
|
||||
pool.connect = stall_connect
|
||||
# Un-patch Pool.connect to break the cyclic reference.
|
||||
self.addCleanup(delattr, pool, "connect")
|
||||
|
||||
# Wait for the background thread to start creating connections
|
||||
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
||||
await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
||||
|
||||
# Assert that application operations do not block.
|
||||
for _ in range(10):
|
||||
@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await client.close()
|
||||
# Add cursor to kill cursors queue
|
||||
del cursor
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: client._kill_cursors_queue,
|
||||
"waited for cursor to be added to queue",
|
||||
)
|
||||
@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
|
||||
await cursor.to_list()
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(client._kill_cursors_queue) == 0,
|
||||
"waited for all killCursor requests to complete",
|
||||
)
|
||||
@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
# Fail over.
|
||||
@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
|
||||
# Total failure.
|
||||
c.kill_host("a:1")
|
||||
@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
|
||||
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
|
||||
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||
|
||||
c.kill_host("a:1")
|
||||
|
||||
@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
self.assertEqual(await c.arbiters, {("c", 3)})
|
||||
# Assert that we create 2 and only 2 pooled connections.
|
||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2)
|
||||
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2)
|
||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
|
||||
# Assert that we do not create connections to arbiters.
|
||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||
@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 1, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
|
||||
self.assertEqual(await c.address, ("c", 3))
|
||||
# Assert that we create 1 pooled connection.
|
||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
||||
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
|
||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||
self.assertEqual(len(arbiter.pool.conns), 1)
|
||||
|
||||
@ -25,7 +25,7 @@ _IS_SYNC = False
|
||||
|
||||
class TestAsyncClientContext(AsyncUnitTest):
|
||||
def test_must_connect(self):
|
||||
if "PYMONGO_MUST_CONNECT" not in os.environ:
|
||||
if not os.environ.get("PYMONGO_MUST_CONNECT"):
|
||||
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
|
||||
|
||||
self.assertTrue(
|
||||
@ -37,7 +37,7 @@ class TestAsyncClientContext(AsyncUnitTest):
|
||||
)
|
||||
|
||||
def test_serverless(self):
|
||||
if "TEST_SERVERLESS" not in os.environ:
|
||||
if not os.environ.get("TEST_SERVERLESS"):
|
||||
raise SkipTest("TEST_SERVERLESS is not set")
|
||||
|
||||
self.assertTrue(
|
||||
@ -47,7 +47,7 @@ class TestAsyncClientContext(AsyncUnitTest):
|
||||
)
|
||||
|
||||
def test_enableTestCommands_is_disabled(self):
|
||||
if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ:
|
||||
if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"):
|
||||
raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set")
|
||||
|
||||
self.assertFalse(
|
||||
@ -56,7 +56,7 @@ class TestAsyncClientContext(AsyncUnitTest):
|
||||
)
|
||||
|
||||
def test_setdefaultencoding_worked(self):
|
||||
if "SETDEFAULTENCODING" not in os.environ:
|
||||
if not os.environ.get("SETDEFAULTENCODING"):
|
||||
raise SkipTest("SETDEFAULTENCODING is not set")
|
||||
|
||||
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
|
||||
|
||||
@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest):
|
||||
warn_context: Any
|
||||
collation: Collation
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
cls.collation = Collation("en_US")
|
||||
cls.warn_context = warnings.catch_warnings()
|
||||
cls.warn_context.__enter__()
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
self.collation = Collation("en_US")
|
||||
self.warn_context = warnings.catch_warnings()
|
||||
self.warn_context.__enter__()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.warn_context.__exit__()
|
||||
cls.warn_context = None
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.warn_context.__exit__()
|
||||
self.warn_context = None
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
await super().asyncTearDown()
|
||||
|
||||
def last_command_started(self):
|
||||
return self.listener.started_events[-1].command
|
||||
|
||||
@ -40,7 +40,6 @@ from test.utils import (
|
||||
async_get_pool,
|
||||
async_is_mongos,
|
||||
async_wait_until,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import encode
|
||||
@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest):
|
||||
db: AsyncDatabase
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.client = AsyncMongoClient(connect=False)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.client = self.simple_client(connect=False)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
|
||||
@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest):
|
||||
class AsyncTestCollection(AsyncIntegrationTest):
|
||||
w: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.w = async_client_context.w # type: ignore
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
|
||||
else:
|
||||
asyncio.run(cls.async_tearDownClass())
|
||||
|
||||
@classmethod
|
||||
async def async_tearDownClass(cls):
|
||||
await cls.db.drop_collection("test_large_limit")
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await self.db.test.drop()
|
||||
await super().asyncSetUp()
|
||||
self.w = async_client_context.w # type: ignore
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.db.test.drop()
|
||||
await self.db.drop_collection("test_large_limit")
|
||||
await super().asyncTearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_concern_collection(self):
|
||||
@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
||||
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
||||
|
||||
await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
|
||||
async def predicate():
|
||||
return await db_w0.test.find_one({"x": 1})
|
||||
|
||||
await async_wait_until(predicate, "find w:0 replaced document")
|
||||
|
||||
async def test_update_bypass_document_validation(self):
|
||||
db = self.db
|
||||
@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await cur.close()
|
||||
cur = None
|
||||
# Wait until the background thread returns the socket.
|
||||
wait_until(lambda: pool.active_sockets == 0, "return socket")
|
||||
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
|
||||
# The socket should be discarded.
|
||||
self.assertEqual(0, len(pool.conns))
|
||||
|
||||
|
||||
@ -19,7 +19,12 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
async_client_context,
|
||||
reset_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.helpers import async_repl_set_step_down
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
|
||||
listener: CMAPListener
|
||||
coll: AsyncCollection
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_replica_set
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = CMAPListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
async def asyncSetUp(self):
|
||||
self.listener = CMAPListener()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
)
|
||||
|
||||
# Ensure connections to all servers in replica set. This is to test
|
||||
# that the is_writable flag is properly updated for connections that
|
||||
# survive a replica set election.
|
||||
await async_ensure_all_connected(cls.client)
|
||||
cls.listener.reset()
|
||||
|
||||
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await async_ensure_all_connected(self.client)
|
||||
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
# Note that all ops use same write-concern as self.db (majority).
|
||||
await self.db.drop_collection("step-down")
|
||||
await self.db.create_collection("step-down")
|
||||
|
||||
@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest):
|
||||
self.assertGreater(len(final_entity_map["events1"]), 0)
|
||||
for event in final_entity_map["events1"]:
|
||||
self.assertIn("PoolCreatedEvent", event["name"])
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
await client.close()
|
||||
|
||||
async def test_store_all_others_as_entities(self):
|
||||
self.scenario_runner = UnifiedSpecTestMixinV1()
|
||||
@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest):
|
||||
self.assertEqual(entity_map["failures"], [])
|
||||
self.assertEqual(entity_map["successes"], 2)
|
||||
self.assertEqual(entity_map["iterations"], 5)
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -34,9 +34,9 @@ from test.utils import (
|
||||
AllowListEventListener,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
async_wait_until,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import decode_all
|
||||
@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest):
|
||||
with self.assertRaises(ExecutionTimeout):
|
||||
await cursor.next()
|
||||
|
||||
def assertCursorKilled():
|
||||
wait_until(
|
||||
async def assertCursorKilled():
|
||||
await async_wait_until(
|
||||
lambda: len(listener.succeeded_events),
|
||||
"find successful killCursors command",
|
||||
)
|
||||
@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
self.assertEqual(1, len(listener.succeeded_events))
|
||||
self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
|
||||
|
||||
assertCursorKilled()
|
||||
await assertCursorKilled()
|
||||
listener.reset()
|
||||
|
||||
cursor = await coll.aggregate([], batchSize=1)
|
||||
@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
with self.assertRaises(ExecutionTimeout):
|
||||
await cursor.next()
|
||||
|
||||
assertCursorKilled()
|
||||
await assertCursorKilled()
|
||||
|
||||
def test_delete_not_initialized(self):
|
||||
# Creating a cursor with invalid arguments will not run __init__
|
||||
@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
|
||||
|
||||
class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
async def test_aggregate_raw(self):
|
||||
c = self.db.test
|
||||
await c.drop()
|
||||
|
||||
@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
|
||||
|
||||
class TestDatabaseAggregation(AsyncIntegrationTest):
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.pipeline: List[Mapping[str, Any]] = [
|
||||
{"$listLocalSessions": {}},
|
||||
{"$limit": 1},
|
||||
|
||||
@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase):
|
||||
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
|
||||
"""Base class for encryption integration tests."""
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@async_client_context.require_version_min(4, 2, -1)
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
|
||||
def assertEncrypted(self, val):
|
||||
self.assertIsInstance(val, Binary)
|
||||
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestClientMaxWireVersion(AsyncIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
@async_client_context.require_version_max(4, 0, 99)
|
||||
async def test_raise_max_wire_version_error(self):
|
||||
@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
||||
"local": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
await cls.client.db.coll.drop()
|
||||
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
await self.client.db.coll.drop()
|
||||
self.vault = await create_key_vault(self.client.keyvault.datakeys)
|
||||
|
||||
# Configure the encrypted field via the local schema_map option.
|
||||
schemas = {
|
||||
@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
||||
}
|
||||
}
|
||||
opts = AutoEncryptionOpts(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
||||
self.KMS_PROVIDERS,
|
||||
"keyvault.datakeys",
|
||||
schema_map=schemas,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
)
|
||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
||||
self.client_encrypted = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.vault.drop()
|
||||
await cls.client.close()
|
||||
await cls.client_encrypted.close()
|
||||
await cls.client_encryption.close()
|
||||
|
||||
def setUp(self):
|
||||
self.listener.reset()
|
||||
|
||||
async def asyncTearDown(self) -> None:
|
||||
await self.vault.drop()
|
||||
|
||||
async def run_test(self, provider_name):
|
||||
# Create data key.
|
||||
master_key: Any = self.MASTER_KEYS[provider_name]
|
||||
@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestCorpus(AsyncEncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
@staticmethod
|
||||
def kms_providers():
|
||||
@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
client_encrypted: AsyncMongoClient
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
db = async_client_context.client.db
|
||||
cls.coll = db.coll
|
||||
await cls.coll.drop()
|
||||
self.coll = db.coll
|
||||
await self.coll.drop()
|
||||
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
||||
json_schema = json_data("limits", "limits-schema.json")
|
||||
await db.create_collection(
|
||||
@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
await coll.insert_one(json_data("limits", "limits-key.json"))
|
||||
|
||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
||||
self.listener = OvertCommandListener()
|
||||
self.client_encrypted = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[self.listener]
|
||||
)
|
||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
||||
self.coll_encrypted = self.client_encrypted.db.coll
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.coll_encrypted.drop()
|
||||
await cls.client_encrypted.close()
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self) -> None:
|
||||
await self.coll_encrypted.drop()
|
||||
|
||||
async def test_01_insert_succeeds_under_2MiB(self):
|
||||
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
||||
@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
||||
self.listener.reset()
|
||||
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
||||
self.assertEqual(
|
||||
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||
)
|
||||
|
||||
async def test_04_bulk_batch_split(self):
|
||||
limits_doc = json_data("limits", "limits-doc.json")
|
||||
@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
doc2.update(limits_doc)
|
||||
self.listener.reset()
|
||||
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
||||
self.assertEqual(
|
||||
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||
)
|
||||
|
||||
async def test_05_insert_succeeds_just_under_16MiB(self):
|
||||
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
||||
@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
"""Prose tests for creating data keys with a custom endpoint."""
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
kms_providers = {
|
||||
"aws": AWS_CREDS,
|
||||
"azure": AZURE_CREDS,
|
||||
@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
self._kmip_host_error = None
|
||||
self._invalid_host_error = None
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.client_encryption.close()
|
||||
await self.client_encryption_invalid.close()
|
||||
|
||||
async def run_test_expected_success(self, provider_name, master_key):
|
||||
data_key_id = await self.client_encryption.create_data_key(
|
||||
provider_name, master_key=master_key
|
||||
@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
KEYVAULT_COLL = "datakeys"
|
||||
client: AsyncMongoClient
|
||||
|
||||
async def asyncSetUp(self):
|
||||
async def _setup(self):
|
||||
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
||||
await create_key_vault(keyvault, self.DEK)
|
||||
|
||||
async def _test_explicit(self, expectation):
|
||||
await self._setup()
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||
async_client_context.client,
|
||||
OPTS,
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
|
||||
ciphertext = await client_encryption.encrypt(
|
||||
"string0",
|
||||
@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
|
||||
|
||||
async def _test_automatic(self, expectation_extjson, payload):
|
||||
await self._setup()
|
||||
encrypted_db = "db"
|
||||
encrypted_coll = "coll"
|
||||
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
||||
@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
client = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
|
||||
coll = client.get_database(encrypted_db).get_collection(
|
||||
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
||||
@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
||||
async def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
self.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def test_explicit(self):
|
||||
return await self._test_explicit(
|
||||
@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio
|
||||
|
||||
|
||||
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
||||
async def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def test_explicit(self):
|
||||
return await self._test_explicit(
|
||||
@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
||||
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.client_test = await self.async_rs_or_single_client(
|
||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||
)
|
||||
@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
self.ciphertext = await client_encryption.encrypt(
|
||||
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
||||
)
|
||||
await client_encryption.close()
|
||||
|
||||
self.client_listener = OvertCommandListener()
|
||||
self.topology_listener = TopologyEventListener()
|
||||
@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
||||
class TestDecryptProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.client = async_client_context.client
|
||||
await self.client.db.drop_collection("decryption_events")
|
||||
await create_key_vault(self.client.keyvault.datakeys)
|
||||
@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
||||
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.client = async_client_context.client
|
||||
await create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
|
||||
assert isinstance(res["encrypted_indexed"], Binary)
|
||||
assert isinstance(res["encrypted_unindexed"], Binary)
|
||||
|
||||
await client_encryption.close()
|
||||
|
||||
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
||||
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
|
||||
@ -3085,21 +3069,15 @@ def start_mongocryptd(port) -> None:
|
||||
_spawn_daemon(args)
|
||||
|
||||
|
||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||
class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
|
||||
mongocryptd_client: AsyncMongoClient
|
||||
MONGOCRYPTD_PORT = 27020
|
||||
|
||||
@classmethod
|
||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
start_mongocryptd(cls.MONGOCRYPTD_PORT)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
start_mongocryptd(self.MONGOCRYPTD_PORT)
|
||||
|
||||
self.listener = OvertCommandListener()
|
||||
self.mongocryptd_client = self.simple_client(
|
||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||
|
||||
@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest):
|
||||
|
||||
class AsyncTestGridFile(AsyncIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
||||
|
||||
async def test_basic(self):
|
||||
|
||||
@ -16,498 +16,447 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
from pymongo.lock import _async_create_condition, _async_create_lock
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.lock import _ACondition
|
||||
if sys.version_info < (3, 13):
|
||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||
# Includes tests for:
|
||||
# - https://github.com/python/cpython/issues/111693
|
||||
# - https://github.com/python/cpython/issues/112202
|
||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_wait(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
result = []
|
||||
|
||||
|
||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||
# Includes tests for:
|
||||
# - https://github.com/python/cpython/issues/111693
|
||||
# - https://github.com/python/cpython/issues/112202
|
||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_wait(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
async def c3(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
self.assertTrue(await cond.acquire())
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.notify(2)
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_wait_cancel(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
await cond.acquire()
|
||||
|
||||
wait = asyncio.create_task(cond.wait())
|
||||
asyncio.get_running_loop().call_soon(wait.cancel)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await wait
|
||||
self.assertFalse(cond._waiters)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_contested(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
wait_task = asyncio.create_task(cond.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
# Notify, but contest the lock before cancelling
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
cond.notify()
|
||||
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
||||
asyncio.get_running_loop().call_soon(cond.release)
|
||||
|
||||
try:
|
||||
await wait_task
|
||||
except asyncio.CancelledError:
|
||||
# Should not happen, since no cancellation points
|
||||
pass
|
||||
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_after_notify(self):
|
||||
# See bpo-32841
|
||||
waited = False
|
||||
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def wait_on_cond():
|
||||
nonlocal waited
|
||||
async with cond:
|
||||
waited = True # Make sure this area was reached
|
||||
await cond.wait()
|
||||
|
||||
waiter = asyncio.create_task(wait_on_cond())
|
||||
await asyncio.sleep(0) # Start waiting
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
await asyncio.sleep(0) # Get to acquire()
|
||||
waiter.cancel()
|
||||
await asyncio.sleep(0) # Activate cancellation
|
||||
cond.release()
|
||||
await asyncio.sleep(0) # Cancellation should occur
|
||||
|
||||
self.assertTrue(waiter.cancelled())
|
||||
self.assertTrue(waited)
|
||||
|
||||
async def test_wait_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait()
|
||||
|
||||
async def test_wait_for(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
presult = False
|
||||
|
||||
def predicate():
|
||||
return presult
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait_for(predicate):
|
||||
result.append(1)
|
||||
cond.release()
|
||||
return True
|
||||
|
||||
t = asyncio.create_task(c1(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
presult = True
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
self.assertTrue(t.done())
|
||||
self.assertTrue(t.result())
|
||||
|
||||
async def test_wait_for_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
# predicate can return true immediately
|
||||
res = await cond.wait_for(lambda: [1, 2, 3])
|
||||
self.assertEqual([1, 2, 3], res)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait_for(lambda: False)
|
||||
|
||||
async def test_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
async def c2(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
async def c3(result):
|
||||
async with cond:
|
||||
async def c3(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
cond.notify(2048)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
|
||||
async def test_context_manager(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
self.assertFalse(cond.locked())
|
||||
async with cond:
|
||||
self.assertTrue(cond.locked())
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
async def test_timeout_in_block(self):
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
async with condition:
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_wakeup(self):
|
||||
# Test that a cancelled error, received when awaiting wakeup,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_re_aquire(self):
|
||||
# Test that a cancelled error, received when re-aquiring lock,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition
|
||||
await cond.acquire()
|
||||
wake = True
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
# Task is now trying to re-acquire the lock, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
cond.release()
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is awaiting initial
|
||||
# wakeup on the wakeup queue.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# Cancel it while it is awaiting to be run.
|
||||
# This cancellation could come from the outside
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup_relock(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is acquiring the lock
|
||||
# again.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# now we sleep for a bit. This allows the target task to wake up and
|
||||
# settle on re-aquiring the lock
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
# Cancel it while awaiting the lock
|
||||
# This cancel could come the outside.
|
||||
c[0].cancel()
|
||||
self.assertTrue(await cond.acquire())
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.notify(2)
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_wait_cancel(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
await cond.acquire()
|
||||
|
||||
wait = asyncio.create_task(cond.wait())
|
||||
asyncio.get_running_loop().call_soon(wait.cancel)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await wait
|
||||
self.assertFalse(cond._waiters)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_contested(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
wait_task = asyncio.create_task(cond.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
# Notify, but contest the lock before cancelling
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
cond.notify()
|
||||
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
||||
asyncio.get_running_loop().call_soon(cond.release)
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
await wait_task
|
||||
except asyncio.CancelledError:
|
||||
# Should not happen, since no cancellation points
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_after_notify(self):
|
||||
# See bpo-32841
|
||||
waited = False
|
||||
|
||||
class TestCondition(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_multiple_loops_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
def tmain(cond):
|
||||
async def atmain(cond):
|
||||
await asyncio.sleep(1)
|
||||
async def wait_on_cond():
|
||||
nonlocal waited
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
waited = True # Make sure this area was reached
|
||||
await cond.wait()
|
||||
|
||||
asyncio.run(atmain(cond))
|
||||
waiter = asyncio.create_task(wait_on_cond())
|
||||
await asyncio.sleep(0) # Start waiting
|
||||
|
||||
t = threading.Thread(target=tmain, args=(cond,))
|
||||
t.start()
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
await asyncio.sleep(0) # Get to acquire()
|
||||
waiter.cancel()
|
||||
await asyncio.sleep(0) # Activate cancellation
|
||||
cond.release()
|
||||
await asyncio.sleep(0) # Cancellation should occur
|
||||
|
||||
async with cond:
|
||||
self.assertTrue(await cond.wait(30))
|
||||
t.join()
|
||||
self.assertTrue(waiter.cancelled())
|
||||
self.assertTrue(waited)
|
||||
|
||||
async def test_multiple_loops_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
results = []
|
||||
async def test_wait_unacquired(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait()
|
||||
|
||||
def tmain(cond, results):
|
||||
async def atmain(cond, results):
|
||||
await asyncio.sleep(1)
|
||||
async def test_wait_for(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
presult = False
|
||||
|
||||
def predicate():
|
||||
return presult
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait_for(predicate):
|
||||
result.append(1)
|
||||
cond.release()
|
||||
return True
|
||||
|
||||
t = asyncio.create_task(c1(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
presult = True
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
self.assertTrue(t.done())
|
||||
self.assertTrue(t.result())
|
||||
|
||||
async def test_wait_for_unacquired(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
# predicate can return true immediately
|
||||
res = await cond.wait_for(lambda: [1, 2, 3])
|
||||
self.assertEqual([1, 2, 3], res)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait_for(lambda: False)
|
||||
|
||||
async def test_notify(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
res = await cond.wait(30)
|
||||
results.append(res)
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
asyncio.run(atmain(cond, results))
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
nthreads = 5
|
||||
threads = []
|
||||
for _ in range(nthreads):
|
||||
threads.append(threading.Thread(target=tmain, args=(cond, results)))
|
||||
for t in threads:
|
||||
t.start()
|
||||
async def c3(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
await asyncio.sleep(2)
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
self.assertEqual(results, [True] * nthreads)
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
cond.notify(2048)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_notify_all(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
|
||||
async def test_context_manager(self):
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
self.assertFalse(cond.locked())
|
||||
async with cond:
|
||||
self.assertTrue(cond.locked())
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
async def test_timeout_in_block(self):
|
||||
condition = _async_create_condition(_async_create_lock())
|
||||
async with condition:
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_wakeup(self):
|
||||
# Test that a cancelled error, received when awaiting wakeup,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_re_aquire(self):
|
||||
# Test that a cancelled error, received when re-aquiring lock,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition
|
||||
await cond.acquire()
|
||||
wake = True
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
# Task is now trying to re-acquire the lock, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
cond.release()
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is awaiting initial
|
||||
# wakeup on the wakeup queue.
|
||||
condition = _async_create_condition(_async_create_lock())
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# Cancel it while it is awaiting to be run.
|
||||
# This cancellation could come from the outside
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup_relock(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is acquiring the lock
|
||||
# again.
|
||||
condition = _async_create_condition(_async_create_lock())
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# now we sleep for a bit. This allows the target task to wake up and
|
||||
# settle on re-aquiring the lock
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Cancel it while awaiting the lock
|
||||
# This cancel could come the outside.
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
def setUpClass(cls) -> None:
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.listener.reset()
|
||||
await super().asyncTearDown()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener], retryWrites=False
|
||||
)
|
||||
|
||||
async def test_started_simple(self):
|
||||
await self.client.pymongo_test.command("ping")
|
||||
@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
|
||||
saved_listeners: Any
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
def setUpClass(cls) -> None:
|
||||
cls.listener = OvertCommandListener()
|
||||
# We plan to call register(), which internally modifies _LISTENERS.
|
||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||
monitoring.register(cls.listener)
|
||||
cls.client = await cls.unmanaged_async_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
await cls.client.pymongo_test.command("ping")
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
monitoring._LISTENERS = cls.saved_listeners
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener.reset()
|
||||
self.client = await self.async_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
await self.client.pymongo_test.command("ping")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
monitoring._LISTENERS = cls.saved_listeners
|
||||
|
||||
async def test_simple(self):
|
||||
await self.client.pymongo_test.command("ping")
|
||||
|
||||
@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
|
||||
RUN_ON_SERVERLESS = True
|
||||
deprecation_filter: DeprecationFilter
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.deprecation_filter = DeprecationFilter()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.deprecation_filter = DeprecationFilter()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.deprecation_filter.stop()
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.deprecation_filter.stop()
|
||||
|
||||
|
||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.client = await self.async_rs_or_single_client(retryWrites=True)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.knobs.disable()
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_actionable_error_message(self):
|
||||
@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
listener: OvertCommandListener
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_no_mmap
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[cls.listener]
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[self.listener]
|
||||
)
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
if async_client_context.is_rs and async_client_context.test_commands_enabled:
|
||||
await self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
||||
@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
await self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||
)
|
||||
self.knobs.disable()
|
||||
|
||||
async def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest):
|
||||
RUN_ON_SERVERLESS = True
|
||||
fail_insert: dict
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_no_mmap
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.fail_insert = {
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.fail_insert = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
"data": {
|
||||
|
||||
@ -38,7 +38,6 @@ from test.utils import (
|
||||
ExceptionCatchingThread,
|
||||
OvertCommandListener,
|
||||
async_wait_until,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import DBRef
|
||||
@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest):
|
||||
client2: AsyncMongoClient
|
||||
sensitive_commands: Set[str]
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_sessions
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
# Create a second client so we can make sure clients cannot share
|
||||
# sessions.
|
||||
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
|
||||
self.client2 = await self.async_rs_or_single_client()
|
||||
|
||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
monitoring._SENSITIVE_COMMANDS.clear()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
||||
await cls.client2.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.listener = SessionTestListener()
|
||||
self.session_checker_listener = SessionTestListener()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener, self.session_checker_listener]
|
||||
)
|
||||
self.addAsyncCleanup(self.client.close)
|
||||
self.db = self.client.pymongo_test
|
||||
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
|
||||
async def asyncTearDown(self):
|
||||
"""All sessions used in the test must be returned to the pool."""
|
||||
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
|
||||
await self.client.drop_database("pymongo_test")
|
||||
used_lsids = self.initial_lsids.copy()
|
||||
for event in self.session_checker_listener.started_events:
|
||||
@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest):
|
||||
current_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
self.assertLessEqual(used_lsids, current_lsids)
|
||||
|
||||
await super().asyncTearDown()
|
||||
|
||||
async def _test_ops(self, client, *ops):
|
||||
listener = client.options.event_listeners[0]
|
||||
|
||||
@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest):
|
||||
listener: SessionTestListener
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.listener = SessionTestListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = SessionTestListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_core(self):
|
||||
|
||||
@ -26,7 +26,7 @@ sys.path[0:0] = [""]
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
wait_until,
|
||||
async_wait_until,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||
)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||
)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
@ -403,21 +403,12 @@ class PatchSessionTimeout:
|
||||
|
||||
|
||||
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.mongos_clients = []
|
||||
if async_client_context.supports_transactions():
|
||||
for address in async_client_context.mongoses:
|
||||
cls.mongos_clients.append(
|
||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
|
||||
|
||||
async def _set_fail_point(self, client, command_args):
|
||||
cmd = {"configureFailPoint": "failCommand"}
|
||||
|
||||
@ -50,6 +50,7 @@ from test.unified_format_shared import (
|
||||
)
|
||||
from test.utils import (
|
||||
async_get_pool,
|
||||
async_wait_until,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
parse_spec_options,
|
||||
@ -304,7 +305,6 @@ class EntityMapUtil:
|
||||
kwargs["h"] = uri
|
||||
client = await self.test.async_rs_or_single_client(**kwargs)
|
||||
self[spec["id"]] = client
|
||||
self.test.addAsyncCleanup(client.close)
|
||||
return
|
||||
elif entity_type == "database":
|
||||
client = self[spec["client"]]
|
||||
@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
await db.create_collection(coll_name, write_concern=wc, **opts)
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
# super call creates internal client cls.client
|
||||
await super()._setup_class()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not await cls.should_run_on(run_on_spec):
|
||||
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
|
||||
|
||||
# add any special-casing for skipping tests here
|
||||
if async_client_context.storage_engine == "mmapv1":
|
||||
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||
cls.TEST_PATH
|
||||
):
|
||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||
|
||||
# Handle mongos_clients for transactions tests.
|
||||
cls.mongos_clients = []
|
||||
if (
|
||||
async_client_context.supports_transactions()
|
||||
and not async_client_context.load_balancer
|
||||
and not async_client_context.serverless
|
||||
):
|
||||
for address in async_client_context.mongoses:
|
||||
cls.mongos_clients.append(
|
||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
||||
)
|
||||
|
||||
def setUpClass(cls) -> None:
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(
|
||||
heartbeat_frequency=0.1,
|
||||
@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
def tearDownClass(cls) -> None:
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
# super call creates internal client cls.client
|
||||
await super().asyncSetUp()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not await self.should_run_on(run_on_spec):
|
||||
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
|
||||
|
||||
# add any special-casing for skipping tests here
|
||||
if async_client_context.storage_engine == "mmapv1":
|
||||
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||
self.TEST_PATH
|
||||
):
|
||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||
|
||||
# Handle mongos_clients for transactions tests.
|
||||
self.mongos_clients = []
|
||||
if (
|
||||
async_client_context.supports_transactions()
|
||||
and not async_client_context.load_balancer
|
||||
and not async_client_context.serverless
|
||||
):
|
||||
for address in async_client_context.mongoses:
|
||||
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
|
||||
|
||||
# process schemaVersion
|
||||
# note: we check major schema version during class generation
|
||||
# note: we do this here because we cannot run assertions in setUpClass
|
||||
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
||||
self.assertLessEqual(
|
||||
version,
|
||||
@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
)
|
||||
|
||||
client = await self.async_single_client("{}:{}".format(*session._pinned_address))
|
||||
self.addAsyncCleanup(client.close)
|
||||
await self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
||||
|
||||
async def _testOperation_createEntities(self, spec):
|
||||
@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
client, event, count = spec["client"], spec["event"], spec["count"]
|
||||
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
|
||||
|
||||
def _testOperation_waitForEvent(self, spec):
|
||||
async def _testOperation_waitForEvent(self, spec):
|
||||
"""Run the waitForEvent test operation.
|
||||
|
||||
Wait for a number of events to be published, or fail.
|
||||
"""
|
||||
client, event, count = spec["client"], spec["event"], spec["count"]
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: self._event_count(client, event) >= count,
|
||||
f"find {count} {event} event(s)",
|
||||
)
|
||||
|
||||
@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
knobs: client_knobs
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.mongos_clients = []
|
||||
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.targets = {}
|
||||
self.listener = None # type: ignore
|
||||
self.pool_listener = None
|
||||
self.server_listener = None
|
||||
self.maxDiff = None
|
||||
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.knobs.disable()
|
||||
|
||||
async def _set_fail_point(self, client, command_args):
|
||||
cmd = SON([("configureFailPoint", "failCommand")])
|
||||
cmd.update(command_args)
|
||||
@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
self.listener = listener
|
||||
self.pool_listener = pool_listener
|
||||
self.server_listener = server_listener
|
||||
# Close the client explicitly to avoid having too many threads open.
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
# Create session0 and session1.
|
||||
sessions = {}
|
||||
|
||||
@ -20,7 +20,7 @@ def event_loop_policy():
|
||||
return asyncio.get_event_loop_policy()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(scope="package", autouse=True)
|
||||
def test_setup_and_teardown():
|
||||
setup()
|
||||
yield
|
||||
|
||||
@ -107,4 +107,4 @@ Automation
|
||||
At MongoDB, Inc. we use a continuous integration job that tests each
|
||||
combination in the matrix. The job starts up Apache, starts a single server
|
||||
or replica set, and runs ``test_client.py`` with the proper arguments.
|
||||
See `run-mod-wsgi-tests.sh <https://github.com/mongodb/mongo-python-driver/blob/master/.evergreen/run-mod-wsgi-tests.sh>`_
|
||||
See `run-mod-wsgi-tests.sh <https://github.com/mongodb/mongo-python-driver/blob/master/.evergreen/scripts/run-mod-wsgi-tests.sh>`_
|
||||
|
||||
@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest):
|
||||
coll: Collection
|
||||
coll_w0: Collection
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.coll = cls.db.test
|
||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.coll = self.db.test
|
||||
self.coll.drop()
|
||||
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
def assertEqualResponse(self, expected, actual):
|
||||
"""Compare response from bulk.execute() to expected response."""
|
||||
@ -785,12 +781,8 @@ class TestBulk(BulkTestBase):
|
||||
|
||||
|
||||
class BulkAuthorizationTestBase(BulkTestBase):
|
||||
@classmethod
|
||||
@client_context.require_auth
|
||||
@client_context.require_no_api_version
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
||||
@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
w: Optional[int]
|
||||
secondary: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.w = client_context.w
|
||||
cls.secondary = None
|
||||
if cls.w is not None and cls.w > 1:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.w = client_context.w
|
||||
self.secondary = None
|
||||
if self.w is not None and self.w > 1:
|
||||
for member in (client_context.hello)["hosts"]:
|
||||
if member != (client_context.hello)["primary"]:
|
||||
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
|
||||
self.secondary = self.single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def async_tearDownClass(cls):
|
||||
if cls.secondary:
|
||||
cls.secondary.close()
|
||||
def tearDown(self):
|
||||
if self.secondary:
|
||||
self.secondary.close()
|
||||
|
||||
def cause_wtimeout(self, requests, ordered):
|
||||
if not client_context.test_commands_enabled:
|
||||
|
||||
@ -820,18 +820,16 @@ class ProseSpecTestsMixin:
|
||||
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
dbs: list
|
||||
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0, -1)
|
||||
@client_context.require_change_streams
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.dbs = [cls.db, cls.client.pymongo_test_2]
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.dbs = [self.db, self.client.pymongo_test_2]
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
for db in cls.dbs:
|
||||
cls.client.drop_database(db)
|
||||
super()._tearDown_class()
|
||||
def tearDown(self):
|
||||
for db in self.dbs:
|
||||
self.client.drop_database(db)
|
||||
super().tearDown()
|
||||
|
||||
def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return client.watch(*args, **kwargs)
|
||||
@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
|
||||
|
||||
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0, -1)
|
||||
@client_context.require_change_streams
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return client[self.db.name].watch(*args, **kwargs)
|
||||
@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
|
||||
|
||||
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
|
||||
@classmethod
|
||||
@client_context.require_change_streams
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Use a new collection for each test.
|
||||
self.watched_collection().drop()
|
||||
self.watched_collection().insert_one({})
|
||||
@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
listener: AllowListEventListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.listener = AllowListEventListener("aggregate", "getMore")
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
self.listener.reset()
|
||||
|
||||
def setUpCluster(self, scenario_dict):
|
||||
|
||||
@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest):
|
||||
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
def setUp(self) -> None:
|
||||
self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest):
|
||||
self.assertFalse(client._topology._opened)
|
||||
|
||||
# Ensure kill cursors thread has not been started.
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertFalse(kc_task and not kc_task.done())
|
||||
# Using the client should open topology and start the thread.
|
||||
client.admin.command("ping")
|
||||
self.assertTrue(client._topology._opened)
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertTrue(kc_task and not kc_task.done())
|
||||
|
||||
def test_close_does_not_open_servers(self):
|
||||
client = self.rs_client(connect=False)
|
||||
@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest):
|
||||
def test_server_selection_timeout(self):
|
||||
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
|
||||
|
||||
@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest):
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
|
||||
)
|
||||
client.close()
|
||||
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
||||
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
# Test invalid timeout in URI ignored and set to default.
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
|
||||
@ -25,7 +25,7 @@ _IS_SYNC = True
|
||||
|
||||
class TestClientContext(UnitTest):
|
||||
def test_must_connect(self):
|
||||
if "PYMONGO_MUST_CONNECT" not in os.environ:
|
||||
if not os.environ.get("PYMONGO_MUST_CONNECT"):
|
||||
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
|
||||
|
||||
self.assertTrue(
|
||||
@ -37,7 +37,7 @@ class TestClientContext(UnitTest):
|
||||
)
|
||||
|
||||
def test_serverless(self):
|
||||
if "TEST_SERVERLESS" not in os.environ:
|
||||
if not os.environ.get("TEST_SERVERLESS"):
|
||||
raise SkipTest("TEST_SERVERLESS is not set")
|
||||
|
||||
self.assertTrue(
|
||||
@ -47,7 +47,7 @@ class TestClientContext(UnitTest):
|
||||
)
|
||||
|
||||
def test_enableTestCommands_is_disabled(self):
|
||||
if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ:
|
||||
if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"):
|
||||
raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set")
|
||||
|
||||
self.assertFalse(
|
||||
@ -56,7 +56,7 @@ class TestClientContext(UnitTest):
|
||||
)
|
||||
|
||||
def test_setdefaultencoding_worked(self):
|
||||
if "SETDEFAULTENCODING" not in os.environ:
|
||||
if not os.environ.get("SETDEFAULTENCODING"):
|
||||
raise SkipTest("SETDEFAULTENCODING is not set")
|
||||
|
||||
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
|
||||
|
||||
@ -97,26 +97,19 @@ class TestCollation(IntegrationTest):
|
||||
warn_context: Any
|
||||
collation: Collation
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
cls.collation = Collation("en_US")
|
||||
cls.warn_context = warnings.catch_warnings()
|
||||
cls.warn_context.__enter__()
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
self.collation = Collation("en_US")
|
||||
self.warn_context = warnings.catch_warnings()
|
||||
self.warn_context.__enter__()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.warn_context.__exit__()
|
||||
cls.warn_context = None
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
self.warn_context.__exit__()
|
||||
self.warn_context = None
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
|
||||
|
||||
@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest):
|
||||
db: Database
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls.client = MongoClient(connect=False)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.client = self.simple_client(connect=False)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, Collection, self.db, 5)
|
||||
@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest):
|
||||
class TestCollection(IntegrationTest):
|
||||
w: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.w = client_context.w # type: ignore
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
|
||||
else:
|
||||
asyncio.run(cls.async_tearDownClass())
|
||||
|
||||
@classmethod
|
||||
def async_tearDownClass(cls):
|
||||
cls.db.drop_collection("test_large_limit")
|
||||
|
||||
def setUp(self):
|
||||
self.db.test.drop()
|
||||
super().setUp()
|
||||
self.w = client_context.w # type: ignore
|
||||
|
||||
def tearDown(self):
|
||||
self.db.test.drop()
|
||||
self.db.drop_collection("test_large_limit")
|
||||
super().tearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_concern_collection(self):
|
||||
@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest):
|
||||
db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
||||
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
||||
|
||||
wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
|
||||
def predicate():
|
||||
return db_w0.test.find_one({"x": 1})
|
||||
|
||||
wait_until(predicate, "find w:0 replaced document")
|
||||
|
||||
def test_update_bypass_document_validation(self):
|
||||
db = self.db
|
||||
|
||||
@ -19,7 +19,12 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
client_context,
|
||||
reset_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.helpers import repl_set_step_down
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
listener: CMAPListener
|
||||
coll: Collection
|
||||
|
||||
@classmethod
|
||||
@client_context.require_replica_set
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = CMAPListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
def setUp(self):
|
||||
self.listener = CMAPListener()
|
||||
self.client = self.rs_or_single_client(
|
||||
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
)
|
||||
|
||||
# Ensure connections to all servers in replica set. This is to test
|
||||
# that the is_writable flag is properly updated for connections that
|
||||
# survive a replica set election.
|
||||
ensure_all_connected(cls.client)
|
||||
cls.listener.reset()
|
||||
|
||||
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
|
||||
def setUp(self):
|
||||
ensure_all_connected(self.client)
|
||||
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
# Note that all ops use same write-concern as self.db (majority).
|
||||
self.db.drop_collection("step-down")
|
||||
self.db.create_collection("step-down")
|
||||
|
||||
@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest):
|
||||
self.assertGreater(len(final_entity_map["events1"]), 0)
|
||||
for event in final_entity_map["events1"]:
|
||||
self.assertIn("PoolCreatedEvent", event["name"])
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
client.close()
|
||||
|
||||
def test_store_all_others_as_entities(self):
|
||||
self.scenario_runner = UnifiedSpecTestMixinV1()
|
||||
@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest):
|
||||
self.assertEqual(entity_map["failures"], [])
|
||||
self.assertEqual(entity_map["successes"], 2)
|
||||
self.assertEqual(entity_map["iterations"], 5)
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
|
||||
|
||||
class TestRawBatchCommandCursor(IntegrationTest):
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def test_aggregate_raw(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
|
||||
@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase):
|
||||
|
||||
class TestCollectionWCustomType(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.drop()
|
||||
|
||||
def tearDown(self):
|
||||
@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest):
|
||||
|
||||
class TestGridFileCustomType(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.drop_collection("fs.files")
|
||||
self.db.drop_collection("fs.chunks")
|
||||
|
||||
@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin:
|
||||
|
||||
|
||||
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_change_streams
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.db.test.delete_many({})
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.delete_many({})
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus
|
||||
|
||||
|
||||
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_change_streams
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.db.test.delete_many({})
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.delete_many({})
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto
|
||||
|
||||
|
||||
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_change_streams
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.db.test.delete_many({})
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.delete_many({})
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
|
||||
@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest):
|
||||
|
||||
class TestDatabaseAggregation(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.pipeline: List[Mapping[str, Any]] = [
|
||||
{"$listLocalSessions": {}},
|
||||
{"$limit": 1},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user