Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-12-02 12:22:41 -06:00
commit c14943ab72
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
119 changed files with 2479 additions and 2515 deletions

0
.evergreen/combine-coverage.sh Normal file → Executable file
View File

File diff suppressed because it is too large Load Diff

2
.evergreen/hatch.sh Normal file → Executable file
View File

@ -29,7 +29,7 @@ else # Set up virtualenv before installing hatch
# Ensure hatch does not write to user or global locations.
touch hatch_config.toml
HATCH_CONFIG=$(pwd)/hatch_config.toml
if [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin
if [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
HATCH_CONFIG=$(cygpath -m "$HATCH_CONFIG")
fi
export HATCH_CONFIG

0
.evergreen/install-dependencies.sh Normal file → Executable file
View File

0
.evergreen/run-azurekms-fail-test.sh Normal file → Executable file
View File

0
.evergreen/run-azurekms-test.sh Normal file → Executable file
View File

0
.evergreen/run-deployed-lambda-aws-tests.sh Normal file → Executable file
View File

0
.evergreen/run-gcpkms-test.sh Normal file → Executable file
View File

0
.evergreen/run-perf-tests.sh Normal file → Executable file
View File

View File

@ -38,6 +38,7 @@ export PIP_PREFER_BINARY=1 # Prefer binary dists by default
set +x
python -c "import sys; sys.exit(sys.prefix == sys.base_prefix)" || (echo "Not inside a virtual env!"; exit 1)
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
# Try to source local Drivers Secrets
if [ -f ./secrets-export.sh ]; then
@ -47,19 +48,24 @@ else
echo "Not sourcing secrets"
fi
# Ensure C extensions have compiled.
if [ -z "${NO_EXT:-}" ] && [ "$PYTHON_IMPL" = "CPython" ]; then
python tools/fail_if_no_c.py
fi
if [ "$AUTH" != "noauth" ]; then
if [ ! -z "$TEST_DATA_LAKE" ]; then
if [ -n "$TEST_DATA_LAKE" ]; then
export DB_USER="mhuser"
export DB_PASSWORD="pencil"
elif [ ! -z "$TEST_SERVERLESS" ]; then
source ${DRIVERS_TOOLS}/.evergreen/serverless/secrets-export.sh
elif [ -n "$TEST_SERVERLESS" ]; then
source "${DRIVERS_TOOLS}"/.evergreen/serverless/secrets-export.sh
export DB_USER=$SERVERLESS_ATLAS_USER
export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD
export MONGODB_URI="$SERVERLESS_URI"
echo "MONGODB_URI=$MONGODB_URI"
export SINGLE_MONGOS_LB_URI=$MONGODB_URI
export MULTI_MONGOS_LB_URI=$MONGODB_URI
elif [ ! -z "$TEST_AUTH_OIDC" ]; then
elif [ -n "$TEST_AUTH_OIDC" ]; then
export DB_USER=$OIDC_ADMIN_USER
export DB_PASSWORD=$OIDC_ADMIN_PWD
export DB_IP="$MONGODB_URI"
@ -240,7 +246,6 @@ python -c 'import sys; print(sys.version)'
# Run the tests with coverage if requested and coverage is installed.
# Only cover CPython. PyPy reports suspiciously low coverage.
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then
# Keep in sync with combine-coverage.sh.
# coverage >=5 is needed for relative_files=true.

View File

@ -0,0 +1,8 @@
#!/bin/bash
set -o xtrace
mkdir out_dir
# shellcheck disable=SC2156
find "$MONGO_ORCHESTRATION_HOME" -name \*.log -exec sh -c 'x="{}"; mv $x $PWD/out_dir/$(basename $(dirname $x))_$(basename $x)' \;
tar zcvf mongodb-logs.tar.gz -C out_dir/ .
rm -rf out_dir

View File

@ -0,0 +1,46 @@
#!/bin/bash
set -o xtrace
# Enable core dumps if enabled on the machine
# Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml
if [ -f /proc/self/coredump_filter ]; then
# Set the shell process (and its children processes) to dump ELF headers (bit 4),
# anonymous shared mappings (bit 1), and anonymous private mappings (bit 0).
echo 0x13 >/proc/self/coredump_filter
if [ -f /sbin/sysctl ]; then
# Check that the core pattern is set explicitly on our distro image instead
# of being the OS's default value. This ensures that coredump names are consistent
# across distros and can be picked up by Evergreen.
core_pattern=$(/sbin/sysctl -n "kernel.core_pattern")
if [ "$core_pattern" = "dump_%e.%p.core" ]; then
echo "Enabling coredumps"
ulimit -c unlimited
fi
fi
fi
if [ "$(uname -s)" = "Darwin" ]; then
core_pattern_mac=$(/usr/sbin/sysctl -n "kern.corefile")
if [ "$core_pattern_mac" = "dump_%N.%P.core" ]; then
echo "Enabling coredumps"
ulimit -c unlimited
fi
fi
if [ -n "${skip_crypt_shared}" ]; then
export SKIP_CRYPT_SHARED=1
fi
MONGODB_VERSION=${VERSION} \
TOPOLOGY=${TOPOLOGY} \
AUTH=${AUTH:-noauth} \
SSL=${SSL:-nossl} \
STORAGE_ENGINE=${STORAGE_ENGINE:-} \
DISABLE_TEST_COMMANDS=${DISABLE_TEST_COMMANDS:-} \
ORCHESTRATION_FILE=${ORCHESTRATION_FILE:-} \
REQUIRE_API_VERSION=${REQUIRE_API_VERSION:-} \
LOAD_BALANCER=${LOAD_BALANCER:-} \
bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh
# run-orchestration generates expansion file with the MONGODB_URI for the cluster

View File

@ -0,0 +1,7 @@
#!/bin/bash
. .evergreen/scripts/env.sh
set -x
export BASE_SHA="$1"
export HEAD_SHA="$2"
bash .evergreen/run-import-time-test.sh

7
.evergreen/scripts/cleanup.sh Executable file
View File

@ -0,0 +1,7 @@
#!/bin/bash
if [ -f "$DRIVERS_TOOLS"/.evergreen/csfle/secrets-export.sh ]; then
. .evergreen/hatch.sh encryption:teardown
fi
rm -rf "${DRIVERS_TOOLS}" || true
rm -f ./secrets-export.sh || true

19
.evergreen/scripts/configure-env.sh Normal file → Executable file
View File

@ -1,4 +1,4 @@
#!/bin/bash -ex
#!/bin/bash -eux
# Get the current unique version of this checkout
# shellcheck disable=SC2154
@ -29,7 +29,7 @@ fi
export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration"
export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin"
cat <<EOT > $SCRIPT_DIR/env.sh
cat <<EOT > "$SCRIPT_DIR"/env.sh
set -o errexit
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
export CURRENT_VERSION="$CURRENT_VERSION"
@ -38,6 +38,21 @@ export DRIVERS_TOOLS="$DRIVERS_TOOLS"
export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
export MONGODB_BINARIES="$MONGODB_BINARIES"
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
export SETDEFAULTENCODING="${SETDEFAULTENCODING:-}"
export SKIP_CSOT_TESTS="${SKIP_CSOT_TESTS:-}"
export MONGODB_STARTED="${MONGODB_STARTED:-}"
export DISABLE_TEST_COMMANDS="${DISABLE_TEST_COMMANDS:-}"
export GREEN_FRAMEWORK="${GREEN_FRAMEWORK:-}"
export NO_EXT="${NO_EXT:-}"
export COVERAGE="${COVERAGE:-}"
export COMPRESSORS="${COMPRESSORS:-}"
export MONGODB_API_VERSION="${MONGODB_API_VERSION:-}"
export SKIP_HATCH="${SKIP_HATCH:-}"
export skip_crypt_shared="${skip_crypt_shared:-}"
export STORAGE_ENGINE="${STORAGE_ENGINE:-}"
export REQUIRE_API_VERSION="${REQUIRE_API_VERSION:-}"
export skip_web_identity_auth_test="${skip_web_identity_auth_test:-}"
export skip_ECS_auth_test="${skip_ECS_auth_test:-}"
export TMPDIR="$MONGO_ORCHESTRATION_HOME/db"
export PATH="$MONGODB_BINARIES:$PATH"

View File

@ -0,0 +1,4 @@
#!/bin/bash
# Download all the task coverage files.
aws s3 cp --recursive s3://"$1"/coverage/"$2"/"$3"/coverage/ coverage/

View File

@ -0,0 +1,8 @@
#!/bin/bash
set +x
. src/.evergreen/scripts/env.sh
# shellcheck disable=SC2044
for filename in $(find $DRIVERS_TOOLS -name \*.json); do
perl -p -i -e "s|ABSOLUTE_PATH_REPLACEMENT_TOKEN|$DRIVERS_TOOLS|g" $filename
done

View File

@ -0,0 +1,5 @@
#!/bin/bash
set +x
. src/.evergreen/scripts/env.sh
echo '{"results": [{ "status": "FAIL", "test_file": "Build", "log_raw": "No test-results.json found was created" } ]}' >$PROJECT_DIRECTORY/test-results.json

View File

@ -0,0 +1,6 @@
#!/bin/bash
set -o xtrace
file="$PROJECT_DIRECTORY/.evergreen/install-dependencies.sh"
# Don't use ${file} syntax here because evergreen treats it as an empty expansion.
[ -f "$file" ] && bash "$file" || echo "$file not available, skipping"

View File

@ -0,0 +1,8 @@
#!/bin/bash
set +x
. src/.evergreen/scripts/env.sh
# shellcheck disable=SC2044
for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do
chmod +x "$i"
done

View File

@ -0,0 +1,12 @@
#!/bin/bash
. src/.evergreen/scripts/env.sh
set -o xtrace
rm -rf $DRIVERS_TOOLS
if [ "$PROJECT" = "drivers-tools" ]; then
# If this was a patch build, doing a fresh clone would not actually test the patch
cp -R $PROJECT_DIRECTORY/ $DRIVERS_TOOLS
else
git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git $DRIVERS_TOOLS
fi
echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" >$MONGO_ORCHESTRATION_HOME/orchestration.config

View File

@ -0,0 +1,7 @@
#!/bin/bash
# Disable xtrace for security reasons (just in case it was accidentally set).
set +x
set -o errexit
bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/atlas_connect
TEST_ATLAS=1 bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg

View File

@ -0,0 +1,15 @@
#!/bin/bash
# shellcheck disable=SC2154
if [ "${skip_ECS_auth_test}" = "true" ]; then
echo "This platform does not support the ECS auth test, skipping..."
exit 0
fi
set -ex
cd "$DRIVERS_TOOLS"/.evergreen/auth_aws
. ./activate-authawsvenv.sh
. aws_setup.sh ecs
export MONGODB_BINARIES="$MONGODB_BINARIES"
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
python aws_tester.py ecs
cd -

View File

@ -0,0 +1,4 @@
#!/bin/bash
set -o xtrace
PYTHON_BINARY=${PYTHON_BINARY} bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh doctest:test

View File

@ -0,0 +1,6 @@
#!/bin/bash
# Disable xtrace for security reasons (just in case it was accidentally set).
set +x
bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/enterprise_auth
TEST_ENTERPRISE_AUTH=1 AUTH=auth bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg

View File

@ -0,0 +1,7 @@
#!/bin/bash
. .evergreen/scripts/env.sh
export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3
export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz
SKIP_SERVERS=1 bash ./.evergreen/setup-encryption.sh
SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/hatch.sh test:test-eg

View File

@ -0,0 +1,22 @@
#!/bin/bash
set -o xtrace
. ${DRIVERS_TOOLS}/.evergreen/download-mongodb.sh || true
get_distro || true
echo $DISTRO
echo $MARCH
echo $OS
uname -a || true
ls /etc/*release* || true
cc --version || true
gcc --version || true
clang --version || true
gcov --version || true
lcov --version || true
llvm-cov --version || true
echo $PATH
ls -la /usr/local/Cellar/llvm/*/bin/ || true
ls -la /usr/local/Cellar/ || true
scan-build --version || true
genhtml --version || true
valgrind --version || true

View File

@ -0,0 +1,3 @@
#!/bin/bash
MONGODB_URI=${MONGODB_URI} bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh start

View File

@ -0,0 +1,5 @@
#!/bin/bash
set -o xtrace
export PYTHON_BINARY=${PYTHON_BINARY}
bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-mockupdb

View File

@ -28,7 +28,7 @@ export MOD_WSGI_SO=/opt/python/mod_wsgi/python_version/$PYTHON_VERSION/mod_wsgi_
export PYTHONHOME=/opt/python/$PYTHON_VERSION
# If MOD_WSGI_EMBEDDED is set use the default embedded mode behavior instead
# of daemon mode (WSGIDaemonProcess).
if [ -n "$MOD_WSGI_EMBEDDED" ]; then
if [ -n "${MOD_WSGI_EMBEDDED:-}" ]; then
export MOD_WSGI_CONF=mod_wsgi_test_embedded.conf
else
export MOD_WSGI_CONF=mod_wsgi_test.conf

View File

@ -13,10 +13,16 @@ set -o errexit # Exit the script with error if any of the commands fail
# mechanism.
# PYTHON_BINARY The Python version to use.
echo "Running MONGODB-AWS authentication tests"
# shellcheck disable=SC2154
if [ "${skip_EC2_auth_test:-}" = "true" ] && { [ "$1" = "ec2" ] || [ "$1" = "web-identity" ]; }; then
echo "This platform does not support the EC2 auth test, skipping..."
exit 0
fi
echo "Running MONGODB-AWS authentication tests for $1"
# Handle credentials and environment setup.
. $DRIVERS_TOOLS/.evergreen/auth_aws/aws_setup.sh $1
. "$DRIVERS_TOOLS"/.evergreen/auth_aws/aws_setup.sh "$1"
# show test output
set -x

View File

@ -0,0 +1,8 @@
#!/bin/bash
TEST_OCSP=1 \
PYTHON_BINARY="${PYTHON_BINARY}" \
CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem" \
OCSP_TLS_SHOULD_SUCCEED="${OCSP_TLS_SHOULD_SUCCEED}" \
bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg
bash "${DRIVERS_TOOLS}"/.evergreen/ocsp/teardown.sh

View File

@ -0,0 +1,4 @@
#!/bin/bash
PROJECT_DIRECTORY=${PROJECT_DIRECTORY}
bash "${PROJECT_DIRECTORY}"/.evergreen/run-perf-tests.sh

55
.evergreen/scripts/run-tests.sh Executable file
View File

@ -0,0 +1,55 @@
#!/bin/bash
# Disable xtrace
set +x
if [ -n "${MONGODB_STARTED}" ]; then
export PYMONGO_MUST_CONNECT=true
fi
if [ -n "${DISABLE_TEST_COMMANDS}" ]; then
export PYMONGO_DISABLE_TEST_COMMANDS=1
fi
if [ -n "${test_encryption}" ]; then
# Disable xtrace (just in case it was accidentally set).
set +x
bash "${DRIVERS_TOOLS}"/.evergreen/csfle/await-servers.sh
export TEST_ENCRYPTION=1
if [ -n "${test_encryption_pyopenssl}" ]; then
export TEST_ENCRYPTION_PYOPENSSL=1
fi
fi
if [ -n "${test_crypt_shared}" ]; then
export TEST_CRYPT_SHARED=1
export CRYPT_SHARED_LIB_PATH=${CRYPT_SHARED_LIB_PATH}
fi
if [ -n "${test_pyopenssl}" ]; then
export TEST_PYOPENSSL=1
fi
if [ -n "${SETDEFAULTENCODING}" ]; then
export SETDEFAULTENCODING="${SETDEFAULTENCODING}"
fi
if [ -n "${test_loadbalancer}" ]; then
export TEST_LOADBALANCER=1
export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}"
export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}"
fi
if [ -n "${test_serverless}" ]; then
export TEST_SERVERLESS=1
fi
if [ -n "${TEST_INDEX_MANAGEMENT:-}" ]; then
export TEST_INDEX_MANAGEMENT=1
fi
if [ -n "${SKIP_CSOT_TESTS}" ]; then
export SKIP_CSOT_TESTS=1
fi
GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \
PYTHON_BINARY=${PYTHON_BINARY} \
NO_EXT=${NO_EXT} \
COVERAGE=${COVERAGE} \
COMPRESSORS=${COMPRESSORS} \
AUTH=${AUTH} \
SSL=${SSL} \
TEST_DATA_LAKE=${TEST_DATA_LAKE:-} \
TEST_SUITES=${TEST_SUITES:-} \
MONGODB_API_VERSION=${MONGODB_API_VERSION} \
SKIP_HATCH=${SKIP_HATCH} \
bash "${PROJECT_DIRECTORY}"/.evergreen/hatch.sh test:test-eg

View File

@ -0,0 +1,21 @@
#!/bin/bash -eu
# Example use: bash run-with-env.sh run-tests.sh {args...}
# Parameter expansion to get just the current directory's name
if [ "${PWD##*/}" == "src" ]; then
. .evergreen/scripts/env.sh
if [ -f ".evergreen/scripts/test-env.sh" ]; then
. .evergreen/scripts/test-env.sh
fi
else
. src/.evergreen/scripts/env.sh
if [ -f "src/.evergreen/scripts/test-env.sh" ]; then
. src/.evergreen/scripts/test-env.sh
fi
fi
set -eu
# shellcheck source=/dev/null
. "$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash
if [ -n "${test_encryption}" ]; then
./.evergreen/hatch.sh encryption:setup
fi

View File

@ -0,0 +1,27 @@
#!/bin/bash -eux
PROJECT_DIRECTORY="$(pwd)"
SCRIPT_DIR="$PROJECT_DIRECTORY/.evergreen/scripts"
if [ -f "$SCRIPT_DIR/test-env.sh" ]; then
echo "Reading $SCRIPT_DIR/test-env.sh file"
. "$SCRIPT_DIR/test-env.sh"
exit 0
fi
cat <<EOT > "$SCRIPT_DIR"/test-env.sh
export test_encryption="${test_encryption:-}"
export test_encryption_pyopenssl="${test_encryption_pyopenssl:-}"
export test_crypt_shared="${test_crypt_shared:-}"
export test_pyopenssl="${test_pyopenssl:-}"
export test_loadbalancer="${test_loadbalancer:-}"
export test_serverless="${test_serverless:-}"
export TEST_INDEX_MANAGEMENT="${TEST_INDEX_MANAGEMENT:-}"
export TEST_DATA_LAKE="${TEST_DATA_LAKE:-}"
export ORCHESTRATION_FILE="${ORCHESTRATION_FILE:-}"
export AUTH="${AUTH:-noauth}"
export SSL="${SSL:-nossl}"
export PYTHON_BINARY="${PYTHON_BINARY:-}"
EOT
chmod +x "$SCRIPT_DIR"/test-env.sh

View File

@ -0,0 +1,5 @@
#!/bin/bash
cd "${DRIVERS_TOOLS}"/.evergreen || exit
DRIVERS_TOOLS=${DRIVERS_TOOLS}
bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh stop

View File

@ -0,0 +1,3 @@
#!/bin/bash
aws s3 cp htmlcov/ s3://"$1"/coverage/"$2"/"$3"/htmlcov/ --recursive --acl public-read --region us-east-1

View File

@ -0,0 +1,11 @@
#!/bin/bash
set +x
. src/.evergreen/scripts/env.sh
# shellcheck disable=SC2044
for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do
< "$i" tr -d '\r' >"$i".new
mv "$i".new "$i"
done
# Copy client certificate because symlinks do not work on Windows.
cp "$DRIVERS_TOOLS"/.evergreen/x509gen/client.pem "$MONGO_ORCHESTRATION_HOME"/lib/client.pem

7
.evergreen/setup-encryption.sh Normal file → Executable file
View File

@ -52,6 +52,9 @@ ls -la libmongocrypt
ls -la libmongocrypt/nocrypto
if [ -z "${SKIP_SERVERS:-}" ]; then
bash ${DRIVERS_TOOLS}/.evergreen/csfle/setup-secrets.sh
bash ${DRIVERS_TOOLS}/.evergreen/csfle/start-servers.sh
PYTHON_BINARY_OLD=${PYTHON_BINARY}
export PYTHON_BINARY=""
bash "${DRIVERS_TOOLS}"/.evergreen/csfle/setup-secrets.sh
export PYTHON_BINARY=$PYTHON_BINARY_OLD
bash "${DRIVERS_TOOLS}"/.evergreen/csfle/start-servers.sh
fi

0
.evergreen/teardown-encryption.sh Normal file → Executable file
View File

View File

@ -17,7 +17,7 @@ find_python3() {
elif [ -d "/Library/Frameworks/Python.Framework/Versions/3.9" ]; then
PYTHON="/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3"
fi
elif [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin
elif [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
PYTHON="C:/python/Python39/python.exe"
else
# Prefer our own toolchain, fall back to mongodb toolchain if it has Python 3.9+.
@ -56,7 +56,7 @@ createvirtualenv () {
# Workaround for bug in older versions of virtualenv.
$VIRTUALENV $VENVPATH 2>/dev/null || $VIRTUALENV $VENVPATH
fi
if [ "Windows_NT" = "$OS" ]; then
if [ "Windows_NT" = "${OS:-}" ]; then
# Workaround https://bugs.python.org/issue32451:
# mongovenv/Scripts/activate: line 3: $'\r': command not found
dos2unix $VENVPATH/Scripts/activate || true
@ -78,6 +78,7 @@ testinstall () {
PYTHON=$1
RELEASE=$2
NO_VIRTUALENV=$3
PYTHON_IMPL=$(python -c "import platform; print(platform.python_implementation())")
if [ -z "$NO_VIRTUALENV" ]; then
createvirtualenv $PYTHON venvtestinstall
@ -86,7 +87,11 @@ testinstall () {
$PYTHON -m pip install --upgrade $RELEASE
cd tools
$PYTHON fail_if_no_c.py
if [ "$PYTHON_IMPL" = "CPython" ]; then
$PYTHON fail_if_no_c.py
fi
$PYTHON -m pip uninstall -y pymongo
cd ..

View File

@ -102,3 +102,15 @@ repos:
# - test/versioned-api/crud-api-version-1-strict.json:514: nin ==> inn, min, bin, nine
# - test/test_client.py:188: te ==> the, be, we, to
args: ["-L", "fle,fo,infinit,isnt,nin,te,aks"]
- repo: local
hooks:
- id: executable-shell
name: executable-shell
entry: chmod +x
language: system
types: [shell]
exclude: |
(?x)(
.evergreen/retry-with-backoff.sh
)

View File

@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
2) License Notice for _asyncio_lock.py
-----------------------------------------
1. This LICENSE AGREEMENT is between the Python Software Foundation
("PSF"), and the Individual or Organization ("Licensee") accessing and
otherwise using this software ("Python") in source or binary form and
its associated documentation.
2. Subject to the terms and conditions of this License Agreement, PSF hereby
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
analyze, test, perform and/or display publicly, prepare derivative works,
distribute, and otherwise use Python alone or in any derivative version,
provided, however, that PSF's License Agreement and PSF's notice of copyright,
i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved"
are retained in Python alone or in any derivative version prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python.
4. PSF is making Python available to Licensee on an "AS IS"
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between PSF and
Licensee. This License Agreement does not grant permission to use PSF
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.
8. By copying, installing or otherwise using Python, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

309
pymongo/_asyncio_lock.py Normal file
View File

@ -0,0 +1,309 @@
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
to port 3.13 fixes to older versions of Python.
Can be removed once we drop Python 3.12 support."""
from __future__ import annotations
import collections
import threading
from asyncio import events, exceptions
from typing import Any, Coroutine, Optional
_global_lock = threading.Lock()
class _LoopBoundMixin:
_loop = None
def _get_loop(self) -> Any:
loop = events._get_running_loop()
if self._loop is None:
with _global_lock:
if self._loop is None:
self._loop = loop
if loop is not self._loop:
raise RuntimeError(f"{self!r} is bound to a different event loop")
return loop
class _ContextManagerMixin:
async def __aenter__(self) -> None:
await self.acquire() # type: ignore[attr-defined]
# We have no use for the "as ..." clause in the with
# statement for locks.
return
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release() # type: ignore[attr-defined]
class Lock(_ContextManagerMixin, _LoopBoundMixin):
"""Primitive lock objects.
A primitive lock is a synchronization primitive that is not owned
by a particular task when locked. A primitive lock is in one
of two states, 'locked' or 'unlocked'.
It is created in the unlocked state. It has two basic methods,
acquire() and release(). When the state is unlocked, acquire()
changes the state to locked and returns immediately. When the
state is locked, acquire() blocks until a call to release() in
another task changes it to unlocked, then the acquire() call
resets it to locked and returns. The release() method should only
be called in the locked state; it changes the state to unlocked
and returns immediately. If an attempt is made to release an
unlocked lock, a RuntimeError will be raised.
When more than one task is blocked in acquire() waiting for
the state to turn to unlocked, only one task proceeds when a
release() call resets the state to unlocked; successive release()
calls will unblock tasks in FIFO order.
Locks also support the asynchronous context management protocol.
'async with lock' statement should be used.
Usage:
lock = Lock()
...
await lock.acquire()
try:
...
finally:
lock.release()
Context manager usage:
lock = Lock()
...
async with lock:
...
Lock objects can be tested for locking state:
if not lock.locked():
await lock.acquire()
else:
# lock is acquired
...
"""
def __init__(self) -> None:
self._waiters: Optional[collections.deque] = None
self._locked = False
def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self._locked else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"
def locked(self) -> bool:
"""Return True if lock is acquired."""
return self._locked
async def acquire(self) -> bool:
"""Acquire a lock.
This method blocks until the lock is unlocked, then sets it to
locked and returns True.
"""
# Implement fair scheduling, where thread always waits
# its turn. Jumping the queue if all are cancelled is an optimization.
if not self._locked and (
self._waiters is None or all(w.cancelled() for w in self._waiters)
):
self._locked = True
return True
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
# Currently the only exception designed be able to occur here.
# Ensure the lock invariant: If lock is not claimed (or about
# to be claimed by us) and there is a Task in waiters,
# ensure that the Task at the head will run.
if not self._locked:
self._wake_up_first()
raise
# assert self._locked is False
self._locked = True
return True
def release(self) -> None:
"""Release a lock.
When the lock is locked, reset it to unlocked, and return.
If any other tasks are blocked waiting for the lock to become
unlocked, allow exactly one of them to proceed.
When invoked on an unlocked lock, a RuntimeError is raised.
There is no return value.
"""
if self._locked:
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock is not acquired.")
def _wake_up_first(self) -> None:
"""Ensure that the first waiter will wake up."""
if not self._waiters:
return
try:
fut = next(iter(self._waiters))
except StopIteration:
return
# .done() means that the waiter is already set to wake up.
if not fut.done():
fut.set_result(True)
class Condition(_ContextManagerMixin, _LoopBoundMixin):
"""Asynchronous equivalent to threading.Condition.
This class implements condition variable objects. A condition variable
allows one or more tasks to wait until they are notified by another
task.
A new Lock object is created and used as the underlying lock.
"""
def __init__(self, lock: Optional[Lock] = None) -> None:
if lock is None:
lock = Lock()
self._lock = lock
# Export the lock's locked(), acquire() and release() methods.
self.locked = lock.locked
self.acquire = lock.acquire
self.release = lock.release
self._waiters: collections.deque = collections.deque()
def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self.locked() else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"
async def wait(self) -> bool:
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
if not self.locked():
raise RuntimeError("cannot wait on un-acquired lock")
fut = self._get_loop().create_future()
self.release()
try:
try:
self._waiters.append(fut)
try:
await fut
return True
finally:
self._waiters.remove(fut)
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except exceptions.CancelledError as e:
err = e
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self._notify(1)
raise
async def wait_for(self, predicate: Any) -> Coroutine:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
"""By default, wake up one task waiting on this condition, if any.
If the calling task has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up n of the tasks waiting for the condition
variable; if fewer than n are waiting, they are all awoken.
Note: an awakened task does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
if not self.locked():
raise RuntimeError("cannot notify on un-acquired lock")
self._notify(n)
def _notify(self, n: int) -> None:
idx = 0
for fut in self._waiters:
if idx >= n:
break
if not fut.done():
idx += 1
fut.set_result(False)
def notify_all(self) -> None:
"""Wake up all tasks waiting on this condition. This method acts
like notify(), but wakes up all waiting tasks instead of one. If the
calling task has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))

49
pymongo/_asyncio_task.py Normal file
View File

@ -0,0 +1,49 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
from __future__ import annotations
import asyncio
import sys
from typing import Any, Coroutine, Optional
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
class _Task(asyncio.Task):
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
super().__init__(coro, name=name)
self._cancel_requests = 0
asyncio._register_task(self)
def cancel(self, msg: Optional[str] = None) -> bool:
self._cancel_requests += 1
return super().cancel(msg=msg)
def uncancel(self) -> int:
if self._cancel_requests > 0:
self._cancel_requests -= 1
return self._cancel_requests
def cancelling(self) -> int:
return self._cancel_requests
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
if sys.version_info >= (3, 11):
return asyncio.create_task(coro, name=name)
return _Task(coro, name=name)

View File

@ -476,7 +476,6 @@ class _AsyncClientBulk:
if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res
except Exception as exc:
# Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive:

View File

@ -45,7 +45,7 @@ from pymongo.common import (
)
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _ALock, _create_lock
from pymongo.lock import _async_create_lock
from pymongo.message import (
_CursorAddress,
_GetMore,
@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: AsyncConnection, more_to_come: bool):
self.conn: Optional[AsyncConnection] = conn
self.more_to_come = more_to_come
self._alock = _ALock(_create_lock())
self._lock = _async_create_lock()
def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come
@ -1299,7 +1299,7 @@ class AsyncCursor(Generic[_DocumentType]):
>>> await cursor.to_list()
Or, so read at most n items from the cursor::
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)

View File

@ -15,6 +15,7 @@
"""Support for explicit client-side field level encryption."""
from __future__ import annotations
import asyncio
import contextlib
import enum
import socket
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptionError(exc) from exc
@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
await database.create_collection(name=name, **kwargs),
encrypted_fields,
)
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -32,6 +32,7 @@ access:
"""
from __future__ import annotations
import asyncio
import contextlib
import os
import warnings
@ -59,8 +60,8 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, uri_parser
from pymongo.asynchronous import client_session, database, periodic_executor
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.asynchronous import client_session, database
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
@ -82,7 +83,11 @@ from pymongo.errors import (
WaitQueueTimeoutError,
WriteConcernError,
)
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_async_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason
@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._default_database_name = dbase
self._lock = _ALock(_create_lock())
self._lock = _async_create_lock()
self._kill_cursors_queue: list = []
self._event_listeners = options.pool_options._event_listeners
@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
await AsyncMongoClient._process_periodic_tasks(client)
return True
executor = periodic_executor.PeriodicExecutor(
executor = periodic_executor.AsyncPeriodicExecutor(
interval=common.KILL_CURSOR_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
)
async with operation.conn_mgr._alock:
async with operation.conn_mgr._lock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return await server.run_operation(
@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try:
if conn_mgr:
async with conn_mgr._alock:
async with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction.
assert address is not None
assert conn_mgr.conn is not None
@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items():
try:
await self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try:
await self._process_kill_cursors()
await self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -16,20 +16,20 @@
from __future__ import annotations
import asyncio
import atexit
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common
from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.periodic_executor import _shutdown_executors
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.lock import _async_create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
@ -76,7 +76,7 @@ class MonitorBase:
await monitor._run() # type:ignore[attr-defined]
return True
executor = periodic_executor.PeriodicExecutor(
executor = periodic_executor.AsyncPeriodicExecutor(
interval=interval, min_interval=min_interval, target=target, name=name
)
@ -112,9 +112,9 @@ class MonitorBase:
"""
self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None:
async def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
await self._executor.join(timeout)
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@ -139,7 +139,7 @@ class Monitor(MonitorBase):
"""
super().__init__(
topology,
"pymongo_server_monitor_thread",
"pymongo_server_monitor_task",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
await self.close()
finally:
if self._executor._stopped:
await self._rtt_monitor.close()
async def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.
@ -252,8 +255,10 @@ class Monitor(MonitorBase):
except (OperationFailure, NotPrimaryError) as exc:
# Update max cluster time even when hello fails.
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
await self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except asyncio.CancelledError:
raise
except ReferenceError:
raise
except Exception as error:
@ -280,7 +285,7 @@ class Monitor(MonitorBase):
await self._reset_connection()
if isinstance(error, _OperationCancelled):
raise
self._rtt_monitor.reset()
await self._rtt_monitor.reset()
# Server type defaults to Unknown.
return ServerDescription(address, error=error)
@ -321,9 +326,9 @@ class Monitor(MonitorBase):
self._conn_id = conn.id
response, round_trip_time = await self._check_with_socket(conn)
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
await self._rtt_monitor.add_sample(round_trip_time)
avg_rtt, min_rtt = self._rtt_monitor.get()
avg_rtt, min_rtt = await self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish:
assert self._listeners is not None
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase):
"""
super().__init__(
topology,
"pymongo_server_rtt_thread",
"pymongo_server_rtt_task",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase):
self._pool = pool
self._moving_average = MovingAverage()
self._moving_min = MovingMinimum()
self._lock = _create_lock()
self._lock = _async_create_lock()
async def close(self) -> None:
self.gc_safe_close()
@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase):
# thread has the socket checked out, it will be closed when checked in.
await self._pool.reset()
def add_sample(self, sample: float) -> None:
async def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
with self._lock:
async with self._lock:
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)
def get(self) -> tuple[Optional[float], float]:
async def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
async with self._lock:
return self._moving_average.get(), self._moving_min.get()
def reset(self) -> None:
async def reset(self) -> None:
"""Reset the average RTT."""
with self._lock:
async with self._lock:
self._moving_average.reset()
self._moving_min.reset()
@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase):
# heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown?
rtt = await self._ping()
self.add_sample(rtt)
await self.add_sample(rtt)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
except asyncio.CancelledError:
raise
except Exception:
await self._pool.reset()
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
shutdown()
atexit.register(_shutdown_resources)
if _IS_SYNC:
atexit.register(_shutdown_resources)

View File

@ -1,219 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run a target function on a background thread."""
from __future__ import annotations
import asyncio
import sys
import threading
import time
import weakref
from typing import Any, Optional
from pymongo.lock import _ALock, _create_lock
_IS_SYNC = False
class PeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background thread.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying thread.
"""
# threading.Event and its internal condition variable are expensive
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
# The executor's design is constrained by several Python issues, see
# "periodic_executor.rst" in this repository.
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread: Optional[threading.Thread] = None
self._name = name
self._skip_sleep = False
self._thread_will_exit = False
self._lock = _ALock(_create_lock())
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
"""
with self._lock:
if self._thread_will_exit:
# If the background thread has read self._stopped as True
# there is a chance that it has not yet exited. The call to
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
assert self._thread is not None
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started: Any = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
# Thread terminated.
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)
# Mitigation to RuntimeError firing when thread starts on shutdown
# https://github.com/python/cpython/issues/114570
try:
thread.start()
except RuntimeError as e:
if "interpreter shutdown" in str(e) or sys.is_finalizing():
self._thread = None
return
raise
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
def join(self, timeout: Optional[int] = None) -> None:
if self._thread is not None:
try:
self._thread.join(timeout)
except (ReferenceError, RuntimeError):
# Thread already terminated, or not yet started.
pass
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _should_stop(self) -> bool:
async with self._lock:
if self._stopped:
self._thread_will_exit = True
return True
return False
async def _run(self) -> None:
while not await self._should_stop():
try:
if not await self._target():
self._stopped = True
break
except BaseException:
async with self._lock:
self._stopped = True
self._thread_will_exit = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()
def _register_executor(executor: PeriodicExecutor) -> None:
ref = weakref.ref(executor, _on_executor_deleted)
_EXECUTORS.add(ref)
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
_EXECUTORS.remove(ref)
def _shutdown_executors() -> None:
if _EXECUTORS is None:
return
# Copy the set. Stopping threads has the side effect of removing executors.
executors = list(_EXECUTORS)
# First signal all executors to close...
for ref in executors:
executor = ref()
if executor:
executor.close()
# ...then try to join them.
for ref in executors:
executor = ref()
if executor:
executor.join(1)
executor = None

View File

@ -23,7 +23,6 @@ import os
import socket
import ssl
import sys
import threading
import time
import weakref
from typing import (
@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
@ -208,11 +211,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error
async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return await condition.wait(timeout)
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
@ -706,6 +704,8 @@ class AsyncConnection:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass
@ -992,8 +992,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock()
self.lock = _ALock(_lock)
self.lock = _async_create_lock()
self._max_connecting_cond = _async_create_condition(self.lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1019,7 +1019,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(_lock))
self.size_cond = _async_create_condition(self.lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1027,7 +1027,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
self._max_connecting_cond = _async_create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
@ -1466,7 +1466,8 @@ class Pool:
async with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size):
if not await _cond_wait(self.size_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.requests < self.max_pool_size:
@ -1489,7 +1490,8 @@ class Pool:
async with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting):
if not await _cond_wait(self._max_connecting_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting:

View File

@ -27,8 +27,7 @@ import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared
from pymongo.asynchronous import periodic_executor
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.pool import Pool
@ -44,7 +43,11 @@ from pymongo.errors import (
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
@ -170,9 +173,10 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
_lock = _create_lock()
self._lock = _ALock(_lock)
self._condition = _ACondition(self._settings.condition_class(_lock))
self._lock = _async_create_lock()
self._condition = _async_create_condition(
self._lock, self._settings.condition_class if _IS_SYNC else None
)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
@ -185,7 +189,7 @@ class Topology:
async def target() -> bool:
return process_events_queue(weak)
executor = periodic_executor.PeriodicExecutor(
executor = periodic_executor.AsyncPeriodicExecutor(
interval=common.EVENTS_QUEUE_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
@ -354,7 +358,7 @@ class Topology:
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.monotonic()
server_descriptions = self._description.apply_selector(
@ -654,7 +658,7 @@ class Topology:
"""Wake all monitors, wait for at least one to check its server."""
async with self._lock:
self._request_check_all()
await self._condition.wait(wait_time)
await _async_cond_wait(self._condition, wait_time)
def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.
@ -742,7 +746,7 @@ class Topology:
if self._publish_server or self._publish_tp:
# Make sure the events executor thread is fully closed before publishing the remaining events
self.__events_executor.close()
self.__events_executor.join(1)
await self.__events_executor.join(1)
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
@property

View File

@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal helpers for lock and condition coordination primitives."""
from __future__ import annotations
import asyncio
import collections
import os
import sys
import threading
import time
import weakref
from typing import Any, Callable, Optional, TypeVar
from asyncio import wait_for
from typing import Any, Optional, TypeVar
import pymongo._asyncio_lock
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
_T = TypeVar("_T")
# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202)
# in older versions of Python
if sys.version_info >= (3, 13):
Lock = asyncio.Lock
Condition = asyncio.Condition
else:
Lock = pymongo._asyncio_lock.Lock
Condition = pymongo._asyncio_lock.Condition
def _create_lock() -> threading.Lock:
"""Represents a lock that is tracked upon instantiation using a WeakSet and
@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock:
return lock
def _async_create_lock() -> Lock:
"""Represents an asyncio.Lock."""
return Lock()
def _create_condition(
lock: threading.Lock, condition_class: Optional[Any] = None
) -> threading.Condition:
"""Represents a threading.Condition."""
if condition_class:
return condition_class(lock)
return threading.Condition(lock)
def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition:
"""Represents an asyncio.Condition."""
if condition_class:
return condition_class(lock)
return Condition(lock)
def _release_locks() -> None:
# Completed the fork, reset all the locks in the child.
for lock in _forkable_locks:
@ -46,202 +81,12 @@ def _release_locks() -> None:
lock.release()
# Needed only for synchro.py compat.
def _Lock(lock: threading.Lock) -> threading.Lock:
return lock
async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool:
try:
return await wait_for(condition.wait(), timeout)
except asyncio.TimeoutError:
return False
class _ALock:
__slots__ = ("_lock",)
def __init__(self, lock: threading.Lock) -> None:
self._lock = lock
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
return self._lock.acquire(blocking=blocking, timeout=timeout)
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._lock.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
def release(self) -> None:
self._lock.release()
async def __aenter__(self) -> _ALock:
await self.a_acquire()
return self
def __enter__(self) -> _ALock:
self._lock.acquire()
return self
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def _safe_set_result(fut: asyncio.Future) -> None:
# Ensure the future hasn't been cancelled before calling set_result.
if not fut.done():
fut.set_result(False)
class _ACondition:
__slots__ = ("_condition", "_waiters")
def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
self._waiters: collections.deque = collections.deque()
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._condition.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
async def wait(self, timeout: Optional[float] = None) -> bool:
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
try:
try:
try:
await asyncio.wait_for(fut, timeout)
return True
except asyncio.TimeoutError:
return False # Return false on timeout for sync pool compat.
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except asyncio.exceptions.CancelledError as e:
err = e
self._waiters.remove((loop, fut))
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self.notify(1)
raise
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
"""By default, wake up one coroutine waiting on this condition, if any.
If the calling coroutine has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up at most n of the coroutines waiting for the
condition variable; it is a no-op if no coroutines are waiting.
Note: an awakened coroutine does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
idx = 0
to_remove = []
for loop, fut in self._waiters:
if idx >= n:
break
if fut.done():
continue
try:
loop.call_soon_threadsafe(_safe_set_result, fut)
except RuntimeError:
# Loop was closed, ignore.
to_remove.append((loop, fut))
continue
idx += 1
for waiter in to_remove:
self._waiters.remove(waiter)
def notify_all(self) -> None:
"""Wake up all threads waiting on this condition. This method acts
like notify(), but wakes up all waiting threads instead of one. If the
calling thread has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))
def locked(self) -> bool:
"""Only needed for tests in test_locks."""
return self._condition._lock.locked() # type: ignore[attr-defined]
def release(self) -> None:
self._condition.release()
async def __aenter__(self) -> _ACondition:
await self.acquire()
return self
def __enter__(self) -> _ACondition:
self._condition.acquire()
return self
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool:
return condition.wait(timeout)

View File

@ -28,7 +28,8 @@ from typing import (
Union,
)
from pymongo import _csot, ssl_support
from pymongo import ssl_support
from pymongo._asyncio_task import create_task
from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception
@ -259,19 +260,20 @@ async def async_receive_data(
sock.settimeout(0.0)
loop = asyncio.get_event_loop()
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
cancellation_task = create_task(_poll_cancellation(conn))
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
else:
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task]
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
await asyncio.wait(pending)
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
@ -314,62 +316,47 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
return mv
# Sync version:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
timed_out = False
# Check if the connection's socket has been manually closed
if sock.fileno() == -1:
return
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
readable = True
else:
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled")
if readable:
return
if timed_out:
raise socket.timeout("timed out")
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
wait_for_read(conn, deadline)
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None:
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except OSError as exc:
if _errno_from_exception(exc) == errno.EINTR:
# To support cancelling a network read, we shorten the socket timeout and
# check for the cancellation signal after each timeout. Alternatively we
# could close the socket but that does not reliably cancel recv() calls
# on all OSes.
orig_timeout = conn.conn.gettimeout()
try:
while bytes_read < length:
if deadline is not None:
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
else:
short_timeout = _POLL_TIMEOUT
conn.set_conn_timeout(short_timeout)
try:
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
# We reached the true deadline.
raise socket.timeout("timed out") from None
except socket.timeout:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
continue
raise
if chunk_length == 0:
raise OSError("connection closed")
except OSError as exc:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk_length == 0:
raise OSError("connection closed")
bytes_read += chunk_length
bytes_read += chunk_length
finally:
conn.set_conn_timeout(orig_timeout)
return mv

View File

@ -23,9 +23,102 @@ import time
import weakref
from typing import Any, Optional
from pymongo._asyncio_task import create_task
from pymongo.lock import _create_lock
_IS_SYNC = True
_IS_SYNC = False
class AsyncPeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background task.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying task.
"""
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._task: Optional[asyncio.Task] = None
self._name = name
self._skip_sleep = False
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def open(self) -> None:
"""Start. Multiple calls have no effect."""
self._stopped = False
if self._task is None or (
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
):
self._task = create_task(self._run(), name=self._name)
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:
try:
await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type]
except asyncio.TimeoutError:
# Task timed out
pass
except asyncio.exceptions.CancelledError:
# Task was already finished, or not yet started.
raise
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _run(self) -> None:
while not self._stopped:
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
raise asyncio.CancelledError
try:
if not await self._target():
self._stopped = True
break
except BaseException:
self._stopped = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
class PeriodicExecutor:
@ -64,19 +157,6 @@ class PeriodicExecutor:
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
@ -104,10 +184,7 @@ class PeriodicExecutor:
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread = threading.Thread(target=self._run, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)

View File

@ -474,7 +474,6 @@ class _ClientBulk:
if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res
except Exception as exc:
# Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive:

View File

@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: Connection, more_to_come: bool):
self.conn: Optional[Connection] = conn
self.more_to_come = more_to_come
self._alock = _create_lock()
self._lock = _create_lock()
def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come
@ -1297,7 +1297,7 @@ class Cursor(Generic[_DocumentType]):
>>> cursor.to_list()
Or, so read at most n items from the cursor::
Or, to read at most n items from the cursor::
>>> cursor.to_list(n)

View File

@ -15,6 +15,7 @@
"""Support for explicit client-side field level encryption."""
from __future__ import annotations
import asyncio
import contextlib
import enum
import socket
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptionError(exc) from exc
@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]):
database.create_collection(name=name, **kwargs),
encrypted_fields,
)
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -32,6 +32,7 @@ access:
"""
from __future__ import annotations
import asyncio
import contextlib
import os
import warnings
@ -58,7 +59,7 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, uri_parser
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.client_options import ClientOptions
from pymongo.errors import (
AutoReconnect,
@ -74,7 +75,11 @@ from pymongo.errors import (
WaitQueueTimeoutError,
WriteConcernError,
)
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks
from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason
@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, periodic_executor
from pymongo.synchronous import client_session, database
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
)
with operation.conn_mgr._alock:
with operation.conn_mgr._lock:
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return server.run_operation(
@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try:
if conn_mgr:
with conn_mgr._alock:
with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction.
assert address is not None
assert conn_mgr.conn is not None
@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items():
try:
self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try:
self._process_kill_cursors()
self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -16,24 +16,24 @@
from __future__ import annotations
import asyncio
import atexit
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common
from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.periodic_executor import _shutdown_executors
if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
self.close()
finally:
if self._executor._stopped:
self._rtt_monitor.close()
def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.
@ -254,6 +257,8 @@ class Monitor(MonitorBase):
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except asyncio.CancelledError:
raise
except ReferenceError:
raise
except Exception as error:
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
self.close()
except asyncio.CancelledError:
raise
except Exception:
self._pool.reset()
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
shutdown()
atexit.register(_shutdown_resources)
if _IS_SYNC:
atexit.register(_shutdown_resources)

View File

@ -23,7 +23,6 @@ import os
import socket
import ssl
import sys
import threading
import time
import weakref
from typing import (
@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _create_lock, _Lock
from pymongo.lock import (
_cond_wait,
_create_condition,
_create_lock,
)
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
@ -208,11 +211,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return condition.wait(timeout)
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
@ -704,6 +702,8 @@ class Connection:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass
@ -988,8 +988,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock()
self.lock = _Lock(_lock)
self.lock = _create_lock()
self._max_connecting_cond = _create_condition(self.lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1015,7 +1015,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = threading.Condition(_lock)
self.size_cond = _create_condition(self.lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1023,7 +1023,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(_lock)
self._max_connecting_cond = _create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
@ -1460,7 +1460,8 @@ class Pool:
with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size):
if not _cond_wait(self.size_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not _cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.requests < self.max_pool_size:
@ -1483,7 +1484,8 @@ class Pool:
with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting):
if not _cond_wait(self._max_connecting_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not _cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting:

View File

@ -27,7 +27,7 @@ import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.errors import (
ConnectionFailure,
InvalidOperation,
@ -39,7 +39,11 @@ from pymongo.errors import (
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _create_lock, _Lock
from pymongo.lock import (
_cond_wait,
_create_condition,
_create_lock,
)
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
@ -56,7 +60,6 @@ from pymongo.server_selectors import (
secondary_server_selector,
writable_server_selector,
)
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.pool import Pool
@ -170,9 +173,10 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
_lock = _create_lock()
self._lock = _Lock(_lock)
self._condition = self._settings.condition_class(_lock)
self._lock = _create_lock()
self._condition = _create_condition(
self._lock, self._settings.condition_class if _IS_SYNC else None
)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
@ -354,7 +358,7 @@ class Topology:
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.monotonic()
server_descriptions = self._description.apply_selector(
@ -652,7 +656,7 @@ class Topology:
"""Wake all monitors, wait for at least one to check its server."""
with self._lock:
self._request_check_all()
self._condition.wait(wait_time)
_cond_wait(self._condition, wait_time)
def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import logging
import multiprocessing
import os
import signal
@ -25,6 +26,7 @@ import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
@ -191,6 +193,8 @@ class ClientContext:
client.close()
def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = self._connect(host, port)
if self.client is not None:
# Return early when connected to dataLake as mongohoused does not
@ -860,6 +864,16 @@ class ClientContext:
client_context = ClientContext()
def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()
class PyMongoTestCase(unittest.TestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase):
class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
def _setup_class(cls):
def setUp(self) -> None:
pass
@classmethod
def _tearDown_class(cls):
def tearDown(self) -> None:
pass
@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase):
db: Database
credentials: Dict[str, str]
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_connection
def _setup_class(cls):
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
def setUp(self) -> None:
if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
raise SkipTest("this test does not support serverless")
cls.client = client_context.client
cls.db = cls.client.pymongo_test
self.client = client_context.client
self.db = self.client.pymongo_test
if client_context.auth_enabled:
cls.credentials = {"username": db_user, "password": db_pwd}
self.credentials = {"username": db_user, "password": db_pwd}
else:
cls.credentials = {}
@classmethod
def _tearDown_class(cls):
pass
self.credentials = {}
def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest):
# MockClients tests that use replicaSet, directConnection=True, pass
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_no_load_balancer
def _setup_class(cls):
pass
@classmethod
def _tearDown_class(cls):
pass
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
self.client_knobs.enable()
def tearDown(self):
def tearDown(self) -> None:
self.client_knobs.disable()
super().tearDown()
@ -1253,7 +1211,6 @@ def teardown():
c.drop_database("pymongo_test_mike")
c.drop_database("pymongo_test_bernie")
c.close()
print_running_clients()

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import logging
import multiprocessing
import os
import signal
@ -25,6 +26,7 @@ import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
@ -191,6 +193,8 @@ class AsyncClientContext:
await client.close()
async def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = await self._connect(host, port)
if self.client is not None:
# Return early when connected to dataLake as mongohoused does not
@ -862,6 +866,16 @@ class AsyncClientContext:
async_client_context = AsyncClientContext()
async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
class AsyncUnitTest(AsyncPyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
async def _setup_class(cls):
async def asyncSetUp(self) -> None:
pass
@classmethod
async def _tearDown_class(cls):
async def asyncTearDown(self) -> None:
pass
@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
db: AsyncDatabase
credentials: Dict[str, str]
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
async def asyncSetUp(self) -> None:
if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
raise SkipTest("this test does not support serverless")
cls.client = async_client_context.client
cls.db = cls.client.pymongo_test
self.client = async_client_context.client
self.db = self.client.pymongo_test
if async_client_context.auth_enabled:
cls.credentials = {"username": db_user, "password": db_pwd}
self.credentials = {"username": db_user, "password": db_pwd}
else:
cls.credentials = {}
@classmethod
async def _tearDown_class(cls):
pass
self.credentials = {}
async def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest):
# MockClients tests that use replicaSet, directConnection=True, pass
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_no_load_balancer
async def _setup_class(cls):
pass
@classmethod
async def _tearDown_class(cls):
pass
def setUp(self):
super().setUp()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
self.client_knobs.enable()
def tearDown(self):
async def asyncTearDown(self) -> None:
self.client_knobs.disable()
super().tearDown()
await super().asyncTearDown()
async def async_setup():
@ -1271,7 +1229,6 @@ async def async_teardown():
await c.drop_database("pymongo_test_mike")
await c.drop_database("pymongo_test_bernie")
await c.close()
print_running_clients()

View File

@ -22,7 +22,7 @@ def event_loop_policy():
return asyncio.get_event_loop_policy()
@pytest_asyncio.fixture(scope="session", autouse=True)
@pytest_asyncio.fixture(scope="package", autouse=True)
async def test_setup_and_teardown():
await async_setup()
yield

View File

@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
coll: AsyncCollection
coll_w0: AsyncCollection
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
async def asyncSetUp(self):
super().setUp()
await super().asyncSetUp()
self.coll = self.db.test
await self.coll.drop()
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response."""
@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase):
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
@classmethod
@async_client_context.require_auth
@async_client_context.require_no_api_version
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
super().setUp()
await super().asyncSetUp()
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
await self.db.command(
"createRole",
@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
w: Optional[int]
secondary: AsyncMongoClient
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.w = async_client_context.w
cls.secondary = None
if cls.w is not None and cls.w > 1:
async def asyncSetUp(self):
await super().asyncSetUp()
self.w = async_client_context.w
self.secondary = None
if self.w is not None and self.w > 1:
for member in (await async_client_context.hello)["hosts"]:
if member != (await async_client_context.hello)["primary"]:
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
self.secondary = await self.async_single_client(*partition_node(member))
break
@classmethod
async def async_tearDownClass(cls):
if cls.secondary:
await cls.secondary.close()
async def asyncTearDown(self):
if self.secondary:
await self.secondary.close()
async def cause_wtimeout(self, requests, ordered):
if not async_client_context.test_commands_enabled:

View File

@ -836,18 +836,16 @@ class ProseSpecTestsMixin:
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
dbs: list
@classmethod
@async_client_context.require_version_min(4, 0, 0, -1)
@async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
cls.dbs = [cls.db, cls.client.pymongo_test_2]
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.dbs = [self.db, self.client.pymongo_test_2]
@classmethod
async def _tearDown_class(cls):
for db in cls.dbs:
await cls.client.drop_database(db)
await super()._tearDown_class()
async def asyncTearDown(self):
for db in self.dbs:
await self.client.drop_database(db)
await super().asyncTearDown()
async def change_stream_with_client(self, client, *args, **kwargs):
return await client.watch(*args, **kwargs)
@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
@classmethod
@async_client_context.require_version_min(4, 0, 0, -1)
@async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
async def change_stream_with_client(self, client, *args, **kwargs):
return await client[self.db.name].watch(*args, **kwargs)
@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi
class TestAsyncCollectionAsyncChangeStream(
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
):
@classmethod
@async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
# Use a new collection for each test.
await self.watched_collection().drop()
await self.watched_collection().insert_one({})
@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
await super()._tearDown_class()
def asyncSetUp(self):
super().asyncSetUp()
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = AllowListEventListener("aggregate", "getMore")
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
self.listener.reset()
async def asyncSetUpCluster(self, scenario_dict):

View File

@ -73,7 +73,6 @@ from test.utils import (
is_greenthread_patched,
lazy_client_trial,
one,
wait_until,
)
import bson
@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.client = await cls.unmanaged_async_rs_or_single_client(
async def asyncSetUp(self) -> None:
self.client = await self.async_rs_or_single_client(
connect=False, serverSelectionTimeoutMS=100
)
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest):
# When the reaper runs at the same time as the get_socket, two
# connections could be created and checked into the pool.
self.assertGreaterEqual(len(server._pool.conns), 1)
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest):
# When the reaper runs at the same time as the get_socket,
# maxPoolSize=1 should prevent two connections from being created.
self.assertEqual(1, len(server._pool.conns))
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
async def test_max_idle_time_reaper_removes_stale(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest):
async with server._pool.checkout() as conn_two:
pass
self.assertIs(conn_one, conn_two)
wait_until(
await async_wait_until(
lambda: len(server._pool.conns) == 0,
"stale socket reaped and new one NOT added to the pool",
)
@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest):
server = await (await client._get_topology()).select_server(
readable_server_selector, _Op.TEST
)
wait_until(
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
)
@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest):
# Assert that if a socket is closed, a new one takes its place
async with server._pool.checkout() as conn:
conn.close_conn(None)
wait_until(
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"a closed socket gets replaced from the pool",
)
@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest):
async with eval(the_repr) as client_two:
self.assertEqual(client_two, client)
def test_getters(self):
wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes")
async def test_getters(self):
await async_wait_until(
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
)
async def test_list_databases(self):
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest):
self.assertFalse(client._topology._opened)
# Ensure kill cursors thread has not been started.
kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertFalse(kc_task and not kc_task.done())
# Using the client should open topology and start the thread.
await client.admin.command("ping")
self.assertTrue(client._topology._opened)
kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertTrue(kc_task and not kc_task.done())
async def test_close_does_not_open_servers(self):
client = await self.async_rs_client(connect=False)
@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest):
async def test_server_selection_timeout(self):
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest):
self.assertRaises(
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
)
await client.close()
client = AsyncMongoClient(
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
self.assertAlmostEqual(0, client.options.server_selection_timeout)
await client.close()
# Test invalid timeout in URI ignored and set to default.
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)
@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest):
await async_client_context.port,
)
await self.async_single_client(uri, event_listeners=[listener])
wait_until(
await async_wait_until(
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
)
@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest):
pool = await async_get_pool(client)
original_connect = pool.connect
def stall_connect(*args, **kwargs):
time.sleep(2)
return original_connect(*args, **kwargs)
async def stall_connect(*args, **kwargs):
await asyncio.sleep(2)
return await original_connect(*args, **kwargs)
pool.connect = stall_connect
# Un-patch Pool.connect to break the cyclic reference.
self.addCleanup(delattr, pool, "connect")
# Wait for the background thread to start creating connections
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections")
# Assert that application operations do not block.
for _ in range(10):
@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest):
await client.close()
# Add cursor to kill cursors queue
del cursor
wait_until(
await async_wait_until(
lambda: client._kill_cursors_queue,
"waited for cursor to be added to queue",
)
@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
await cursor.to_list()
self.assertTrue(conn.closed)
wait_until(
await async_wait_until(
lambda: len(client._kill_cursors_queue) == 0,
"waited for all killCursor requests to complete",
)
@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1))
# Fail over.
@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
# Total failure.
c.kill_host("a:1")
@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
wait_until(lambda: len(c.nodes) == 2, "connect")
await async_wait_until(lambda: len(c.nodes) == 2, "connect")
c.kill_host("a:1")
@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1))
self.assertEqual(await c.arbiters, {("c", 3)})
# Assert that we create 2 and only 2 pooled connections.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2)
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
# Assert that we do not create connections to arbiters.
arbiter = c._topology.get_server_by_address(("c", 3))
@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 1, "connect")
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3))
# Assert that we create 1 pooled connection.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.conns), 1)

View File

@ -25,7 +25,7 @@ _IS_SYNC = False
class TestAsyncClientContext(AsyncUnitTest):
def test_must_connect(self):
if "PYMONGO_MUST_CONNECT" not in os.environ:
if not os.environ.get("PYMONGO_MUST_CONNECT"):
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
self.assertTrue(
@ -37,7 +37,7 @@ class TestAsyncClientContext(AsyncUnitTest):
)
def test_serverless(self):
if "TEST_SERVERLESS" not in os.environ:
if not os.environ.get("TEST_SERVERLESS"):
raise SkipTest("TEST_SERVERLESS is not set")
self.assertTrue(
@ -47,7 +47,7 @@ class TestAsyncClientContext(AsyncUnitTest):
)
def test_enableTestCommands_is_disabled(self):
if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ:
if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"):
raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set")
self.assertFalse(
@ -56,7 +56,7 @@ class TestAsyncClientContext(AsyncUnitTest):
)
def test_setdefaultencoding_worked(self):
if "SETDEFAULTENCODING" not in os.environ:
if not os.environ.get("SETDEFAULTENCODING"):
raise SkipTest("SETDEFAULTENCODING is not set")
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])

View File

@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest):
warn_context: Any
collation: Collation
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
cls.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
self.db = self.client.pymongo_test
self.collation = Collation("en_US")
self.warn_context = warnings.catch_warnings()
self.warn_context.__enter__()
@classmethod
async def _tearDown_class(cls):
cls.warn_context.__exit__()
cls.warn_context = None
await cls.client.close()
await super()._tearDown_class()
def tearDown(self):
async def asyncTearDown(self) -> None:
self.warn_context.__exit__()
self.warn_context = None
self.listener.reset()
super().tearDown()
await super().asyncTearDown()
def last_command_started(self):
return self.listener.started_events[-1].command

View File

@ -40,7 +40,6 @@ from test.utils import (
async_get_pool,
async_is_mongos,
async_wait_until,
wait_until,
)
from bson import encode
@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest):
db: AsyncDatabase
client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.client = AsyncMongoClient(connect=False)
cls.db = cls.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.client = self.simple_client(connect=False)
self.db = self.client.pymongo_test
def test_collection(self):
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest):
class AsyncTestCollection(AsyncIntegrationTest):
w: int
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.w = async_client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
else:
asyncio.run(cls.async_tearDownClass())
@classmethod
async def async_tearDownClass(cls):
await cls.db.drop_collection("test_large_limit")
async def asyncSetUp(self):
await self.db.test.drop()
await super().asyncSetUp()
self.w = async_client_context.w # type: ignore
async def asyncTearDown(self):
await self.db.test.drop()
await self.db.drop_collection("test_large_limit")
await super().asyncTearDown()
@contextlib.contextmanager
def write_concern_collection(self):
@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest):
await db.test.insert_one({"y": 1}, bypass_document_validation=True)
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
async def predicate():
return await db_w0.test.find_one({"x": 1})
await async_wait_until(predicate, "find w:0 replaced document")
async def test_update_bypass_document_validation(self):
db = self.db
@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
await cur.close()
cur = None
# Wait until the background thread returns the socket.
wait_until(lambda: pool.active_sockets == 0, "return socket")
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# The socket should be discarded.
self.assertEqual(0, len(pool.conns))

View File

@ -19,7 +19,12 @@ import sys
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
reset_client_context,
unittest,
)
from test.asynchronous.helpers import async_repl_set_step_down
from test.utils import (
CMAPListener,
@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
listener: CMAPListener
coll: AsyncCollection
@classmethod
@async_client_context.require_replica_set
async def _setup_class(cls):
await super()._setup_class()
cls.listener = CMAPListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
async def asyncSetUp(self):
self.listener = CMAPListener()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for connections that
# survive a replica set election.
await async_ensure_all_connected(cls.client)
cls.listener.reset()
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
async def asyncSetUp(self):
await async_ensure_all_connected(self.client)
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
# Note that all ops use same write-concern as self.db (majority).
await self.db.drop_collection("step-down")
await self.db.create_collection("step-down")

View File

@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest):
self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()
async def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1()
@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest):
self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()
if __name__ == "__main__":

View File

@ -34,9 +34,9 @@ from test.utils import (
AllowListEventListener,
EventListener,
OvertCommandListener,
async_wait_until,
delay,
ignore_deprecations,
wait_until,
)
from bson import decode_all
@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest):
with self.assertRaises(ExecutionTimeout):
await cursor.next()
def assertCursorKilled():
wait_until(
async def assertCursorKilled():
await async_wait_until(
lambda: len(listener.succeeded_events),
"find successful killCursors command",
)
@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest):
self.assertEqual(1, len(listener.succeeded_events))
self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
assertCursorKilled()
await assertCursorKilled()
listener.reset()
cursor = await coll.aggregate([], batchSize=1)
@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest):
with self.assertRaises(ExecutionTimeout):
await cursor.next()
assertCursorKilled()
await assertCursorKilled()
def test_delete_not_initialized(self):
# Creating a cursor with invalid arguments will not run __init__
@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest):
class TestRawBatchCommandCursor(AsyncIntegrationTest):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def test_aggregate_raw(self):
c = self.db.test
await c.drop()

View File

@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest):
class TestDatabaseAggregation(AsyncIntegrationTest):
def setUp(self):
async def asyncSetUp(self):
await super().asyncSetUp()
self.pipeline: List[Mapping[str, Any]] = [
{"$listLocalSessions": {}},
{"$limit": 1},

View File

@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase):
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
"""Base class for encryption integration tests."""
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@async_client_context.require_version_min(4, 2, -1)
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
def assertEncrypted(self, val):
self.assertIsInstance(val, Binary)
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
class TestClientMaxWireVersion(AsyncIntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
@async_client_context.require_version_max(4, 0, 99)
async def test_raise_max_wire_version_error(self):
@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
"local": None,
}
@classmethod
@unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set",
)
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
await cls.client.db.coll.drop()
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
await self.client.db.coll.drop()
self.vault = await create_key_vault(self.client.keyvault.datakeys)
# Configure the encrypted field via the local schema_map option.
schemas = {
@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
}
}
opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
self.KMS_PROVIDERS,
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
)
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
cls.client_encryption = cls.unmanaged_create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
self.client_encryption = self.create_client_encryption(
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
)
@classmethod
async def _tearDown_class(cls):
await cls.vault.drop()
await cls.client.close()
await cls.client_encrypted.close()
await cls.client_encryption.close()
def setUp(self):
self.listener.reset()
async def asyncTearDown(self) -> None:
await self.vault.drop()
async def run_test(self, provider_name):
# Create data key.
master_key: Any = self.MASTER_KEYS[provider_name]
@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
class TestCorpus(AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
@staticmethod
def kms_providers():
@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
client_encrypted: AsyncMongoClient
listener: OvertCommandListener
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
db = async_client_context.client.db
cls.coll = db.coll
await cls.coll.drop()
self.coll = db.coll
await self.coll.drop()
# Configure the encrypted 'db.coll' collection via jsonSchema.
json_schema = json_data("limits", "limits-schema.json")
await db.create_collection(
@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
await coll.insert_one(json_data("limits", "limits-key.json"))
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener()
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener]
self.listener = OvertCommandListener()
self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[self.listener]
)
cls.coll_encrypted = cls.client_encrypted.db.coll
self.coll_encrypted = self.client_encrypted.db.coll
@classmethod
async def _tearDown_class(cls):
await cls.coll_encrypted.drop()
await cls.client_encrypted.close()
await super()._tearDown_class()
async def asyncTearDown(self) -> None:
await self.coll_encrypted.drop()
async def test_01_insert_succeeds_under_2MiB(self):
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
async def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json")
@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
doc2.update(limits_doc)
self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
async def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
"""Prose tests for creating data keys with a custom endpoint."""
@classmethod
@unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set",
)
async def _setup_class(cls):
await super()._setup_class()
def setUp(self):
async def asyncSetUp(self):
await super().asyncSetUp()
kms_providers = {
"aws": AWS_CREDS,
"azure": AZURE_CREDS,
@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
self._kmip_host_error = None
self._invalid_host_error = None
async def asyncTearDown(self):
await self.client_encryption.close()
await self.client_encryption_invalid.close()
async def run_test_expected_success(self, provider_name, master_key):
data_key_id = await self.client_encryption.create_data_key(
provider_name, master_key=master_key
@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys"
client: AsyncMongoClient
async def asyncSetUp(self):
async def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
await create_key_vault(keyvault, self.DEK)
async def _test_explicit(self, expectation):
await self._setup()
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
async_client_context.client,
OPTS,
)
self.addAsyncCleanup(client_encryption.close)
ciphertext = await client_encryption.encrypt(
"string0",
@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
async def _test_automatic(self, expectation_extjson, payload):
await self._setup()
encrypted_db = "db"
encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
client = await self.async_rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addAsyncCleanup(client.aclose)
coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
async def _setup_class(cls):
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super()._setup_class()
async def asyncSetUp(self):
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
self.DEK = json_data(BASE, "custom", "azure-dek.json")
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super().asyncSetUp()
async def test_explicit(self):
return await self._test_explicit(
@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
async def _setup_class(cls):
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super()._setup_class()
async def asyncSetUp(self):
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super().asyncSetUp()
async def test_explicit(self):
return await self._test_explicit(
@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client_test = await self.async_rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
self.ciphertext = await client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
)
await client_encryption.close()
self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener()
@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client
await self.client.db.drop_collection("decryption_events")
await create_key_vault(self.client.keyvault.datakeys)
@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client
await create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary)
await client_encryption.close()
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
@ -3085,21 +3069,15 @@ def start_mongocryptd(port) -> None:
_spawn_daemon(args)
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
mongocryptd_client: AsyncMongoClient
MONGOCRYPTD_PORT = 27020
@classmethod
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
async def _setup_class(cls):
await super()._setup_class()
start_mongocryptd(cls.MONGOCRYPTD_PORT)
@classmethod
async def _tearDown_class(cls):
await super()._tearDown_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
start_mongocryptd(self.MONGOCRYPTD_PORT)
self.listener = OvertCommandListener()
self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]

View File

@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest):
class AsyncTestGridFile(AsyncIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
async def test_basic(self):

View File

@ -16,498 +16,447 @@ from __future__ import annotations
import asyncio
import sys
import threading
import unittest
from pymongo.lock import _async_create_condition, _async_create_lock
sys.path[0:0] = [""]
from pymongo.lock import _ACondition
if sys.version_info < (3, 13):
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _async_create_condition(_async_create_lock())
result = []
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
await cond.acquire()
if await cond.wait():
result.append(1)
return True
async def c2(result):
await cond.acquire()
if await cond.wait():
result.append(2)
return True
async def c3(result):
await cond.acquire()
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertFalse(cond.locked())
self.assertTrue(await cond.acquire())
cond.notify()
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.notify(2)
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2, 3], result)
self.assertTrue(cond.locked())
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
wait = asyncio.create_task(cond.wait())
asyncio.get_running_loop().call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
self.assertFalse(cond._waiters)
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
self.assertTrue(cond.locked())
wait_task = asyncio.create_task(cond.wait())
await asyncio.sleep(0)
self.assertFalse(cond.locked())
# Notify, but contest the lock before cancelling
await cond.acquire()
self.assertTrue(cond.locked())
cond.notify()
asyncio.get_running_loop().call_soon(wait_task.cancel)
asyncio.get_running_loop().call_soon(cond.release)
try:
await wait_task
except asyncio.CancelledError:
# Should not happen, since no cancellation points
pass
self.assertTrue(cond.locked())
async def test_wait_cancel_after_notify(self):
# See bpo-32841
waited = False
cond = _ACondition(threading.Condition(threading.Lock()))
async def wait_on_cond():
nonlocal waited
async with cond:
waited = True # Make sure this area was reached
await cond.wait()
waiter = asyncio.create_task(wait_on_cond())
await asyncio.sleep(0) # Start waiting
await cond.acquire()
cond.notify()
await asyncio.sleep(0) # Get to acquire()
waiter.cancel()
await asyncio.sleep(0) # Activate cancellation
cond.release()
await asyncio.sleep(0) # Cancellation should occur
self.assertTrue(waiter.cancelled())
self.assertTrue(waited)
async def test_wait_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
with self.assertRaises(RuntimeError):
await cond.wait()
async def test_wait_for(self):
cond = _ACondition(threading.Condition(threading.Lock()))
presult = False
def predicate():
return presult
result = []
async def c1(result):
await cond.acquire()
if await cond.wait_for(predicate):
result.append(1)
cond.release()
return True
t = asyncio.create_task(c1(result))
await asyncio.sleep(0)
self.assertEqual([], result)
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([], result)
presult = True
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
self.assertEqual([1, 2, 3], res)
with self.assertRaises(RuntimeError):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
async def c1(result):
await cond.acquire()
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
async def c2(result):
await cond.acquire()
if await cond.wait():
result.append(2)
return True
async def c3(result):
async with cond:
async def c3(result):
await cond.acquire()
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify(1)
await asyncio.sleep(1)
self.assertEqual([1], result)
async with cond:
cond.notify(1)
cond.notify(2048)
await asyncio.sleep(1)
self.assertEqual([1, 2, 3], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify_all()
await asyncio.sleep(1)
self.assertEqual([1, 2], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _ACondition(threading.Condition(threading.Lock()))
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _ACondition(threading.Condition(threading.Lock()))
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-aquiring lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# Cancel it while it is awaiting to be run.
# This cancellation could come from the outside
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup_relock(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# now we sleep for a bit. This allows the target task to wake up and
# settle on re-aquiring the lock
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertFalse(cond.locked())
# Cancel it while awaiting the lock
# This cancel could come the outside.
c[0].cancel()
self.assertTrue(await cond.acquire())
cond.notify()
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.notify(2)
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2, 3], result)
self.assertTrue(cond.locked())
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _async_create_condition(_async_create_lock())
await cond.acquire()
wait = asyncio.create_task(cond.wait())
asyncio.get_running_loop().call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
self.assertFalse(cond._waiters)
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _async_create_condition(_async_create_lock())
await cond.acquire()
self.assertTrue(cond.locked())
wait_task = asyncio.create_task(cond.wait())
await asyncio.sleep(0)
self.assertFalse(cond.locked())
# Notify, but contest the lock before cancelling
await cond.acquire()
self.assertTrue(cond.locked())
cond.notify()
asyncio.get_running_loop().call_soon(wait_task.cancel)
asyncio.get_running_loop().call_soon(cond.release)
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
await wait_task
except asyncio.CancelledError:
# Should not happen, since no cancellation points
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
self.assertTrue(cond.locked())
async def test_wait_cancel_after_notify(self):
# See bpo-32841
waited = False
class TestCondition(unittest.IsolatedAsyncioTestCase):
async def test_multiple_loops_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
def tmain(cond):
async def atmain(cond):
await asyncio.sleep(1)
async def wait_on_cond():
nonlocal waited
async with cond:
cond.notify(1)
waited = True # Make sure this area was reached
await cond.wait()
asyncio.run(atmain(cond))
waiter = asyncio.create_task(wait_on_cond())
await asyncio.sleep(0) # Start waiting
t = threading.Thread(target=tmain, args=(cond,))
t.start()
await cond.acquire()
cond.notify()
await asyncio.sleep(0) # Get to acquire()
waiter.cancel()
await asyncio.sleep(0) # Activate cancellation
cond.release()
await asyncio.sleep(0) # Cancellation should occur
async with cond:
self.assertTrue(await cond.wait(30))
t.join()
self.assertTrue(waiter.cancelled())
self.assertTrue(waited)
async def test_multiple_loops_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
results = []
async def test_wait_unacquired(self):
cond = _async_create_condition(_async_create_lock())
with self.assertRaises(RuntimeError):
await cond.wait()
def tmain(cond, results):
async def atmain(cond, results):
await asyncio.sleep(1)
async def test_wait_for(self):
cond = _async_create_condition(_async_create_lock())
presult = False
def predicate():
return presult
result = []
async def c1(result):
await cond.acquire()
if await cond.wait_for(predicate):
result.append(1)
cond.release()
return True
t = asyncio.create_task(c1(result))
await asyncio.sleep(0)
self.assertEqual([], result)
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([], result)
presult = True
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _async_create_condition(_async_create_lock())
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
self.assertEqual([1, 2, 3], res)
with self.assertRaises(RuntimeError):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
async with cond:
res = await cond.wait(30)
results.append(res)
if await cond.wait():
result.append(1)
return True
asyncio.run(atmain(cond, results))
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
nthreads = 5
threads = []
for _ in range(nthreads):
threads.append(threading.Thread(target=tmain, args=(cond, results)))
for t in threads:
t.start()
async def c3(result):
async with cond:
if await cond.wait():
result.append(3)
return True
await asyncio.sleep(2)
async with cond:
cond.notify_all()
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
for t in threads:
t.join()
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertEqual(results, [True] * nthreads)
async with cond:
cond.notify(1)
await asyncio.sleep(1)
self.assertEqual([1], result)
async with cond:
cond.notify(1)
cond.notify(2048)
await asyncio.sleep(1)
self.assertEqual([1, 2, 3], result)
if __name__ == "__main__":
unittest.main()
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify_all()
await asyncio.sleep(1)
self.assertEqual([1, 2], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _async_create_condition(_async_create_lock())
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _async_create_condition(_async_create_lock())
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = _async_create_condition(_async_create_lock())
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-aquiring lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = _async_create_condition(_async_create_lock())
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _async_create_condition(_async_create_lock())
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# Cancel it while it is awaiting to be run.
# This cancellation could come from the outside
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup_relock(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _async_create_condition(_async_create_lock())
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# now we sleep for a bit. This allows the target task to wake up and
# settle on re-aquiring the lock
await asyncio.sleep(0)
# Cancel it while awaiting the lock
# This cancel could come the outside.
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
if __name__ == "__main__":
unittest.main()

View File

@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
listener: EventListener
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
await super()._tearDown_class()
async def asyncTearDown(self):
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener.reset()
await super().asyncTearDown()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], retryWrites=False
)
async def test_started_simple(self):
await self.client.pymongo_test.command("ping")
@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
saved_listeners: Any
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
# We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener)
cls.client = await cls.unmanaged_async_single_client()
# Get one (authenticated) socket in the pool.
await cls.client.pymongo_test.command("ping")
@classmethod
async def _tearDown_class(cls):
monitoring._LISTENERS = cls.saved_listeners
await cls.client.close()
await super()._tearDown_class()
@async_client_context.require_connection
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener.reset()
self.client = await self.async_single_client()
# Get one (authenticated) socket in the pool.
await self.client.pymongo_test.command("ping")
@classmethod
def tearDownClass(cls):
monitoring._LISTENERS = cls.saved_listeners
async def test_simple(self):
await self.client.pymongo_test.command("ping")

View File

@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.deprecation_filter = DeprecationFilter()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.deprecation_filter = DeprecationFilter()
@classmethod
async def _tearDown_class(cls):
cls.deprecation_filter.stop()
await super()._tearDown_class()
async def asyncTearDown(self) -> None:
self.deprecation_filter.stop()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.client = await self.async_rs_or_single_client(retryWrites=True)
self.db = self.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
async def asyncTearDown(self) -> None:
self.knobs.disable()
@async_client_context.require_no_standalone
async def test_actionable_error_message(self):
@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener
knobs: client_knobs
@classmethod
@async_client_context.require_no_mmap
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
retryWrites=True, event_listeners=[cls.listener]
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(
retryWrites=True, event_listeners=[self.listener]
)
cls.db = cls.client.pymongo_test
self.db = self.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
async def asyncSetUp(self):
if async_client_context.is_rs and async_client_context.test_commands_enabled:
await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
async def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()
@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest):
RUN_ON_SERVERLESS = True
fail_insert: dict
@classmethod
@async_client_context.require_replica_set
@async_client_context.require_no_mmap
@async_client_context.require_failCommand_fail_point
async def _setup_class(cls):
await super()._setup_class()
cls.fail_insert = {
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.fail_insert = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {

View File

@ -38,7 +38,6 @@ from test.utils import (
ExceptionCatchingThread,
OvertCommandListener,
async_wait_until,
wait_until,
)
from bson import DBRef
@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest):
client2: AsyncMongoClient
sensitive_commands: Set[str]
@classmethod
@async_client_context.require_sessions
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
# Create a second client so we can make sure clients cannot share
# sessions.
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
self.client2 = await self.async_rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
monitoring._SENSITIVE_COMMANDS.clear()
@classmethod
async def _tearDown_class(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
await cls.client2.close()
await super()._tearDown_class()
async def asyncSetUp(self):
self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener]
)
self.addAsyncCleanup(self.client.close)
self.db = self.client.pymongo_test
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
async def asyncTearDown(self):
"""All sessions used in the test must be returned to the pool."""
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
await self.client.drop_database("pymongo_test")
used_lsids = self.initial_lsids.copy()
for event in self.session_checker_listener.started_events:
@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest):
current_lsids = {s["id"] for s in session_ids(self.client)}
self.assertLessEqual(used_lsids, current_lsids)
await super().asyncTearDown()
async def _test_ops(self, client, *ops):
listener = client.options.event_listeners[0]
@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener
client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
@async_client_context.require_sessions
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = SessionTestListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
@async_client_context.require_no_standalone
async def test_core(self):

View File

@ -26,7 +26,7 @@ sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import (
OvertCommandListener,
wait_until,
async_wait_until,
)
from typing import List
@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase):
client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000
)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase):
client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000
)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
@ -403,21 +403,12 @@ class PatchSessionTimeout:
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.mongos_clients = []
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.mongos_clients = []
if async_client_context.supports_transactions():
for address in async_client_context.mongoses:
cls.mongos_clients.append(
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
@classmethod
async def _tearDown_class(cls):
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
async def _set_fail_point(self, client, command_args):
cmd = {"configureFailPoint": "failCommand"}

View File

@ -50,6 +50,7 @@ from test.unified_format_shared import (
)
from test.utils import (
async_get_pool,
async_wait_until,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
@ -304,7 +305,6 @@ class EntityMapUtil:
kwargs["h"] = uri
client = await self.test.async_rs_or_single_client(**kwargs)
self[spec["id"]] = client
self.test.addAsyncCleanup(client.close)
return
elif entity_type == "database":
client = self[spec["client"]]
@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
await db.create_collection(coll_name, write_concern=wc, **opts)
@classmethod
async def _setup_class(cls):
# super call creates internal client cls.client
await super()._setup_class()
# process file-level runOnRequirements
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
if not await cls.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if async_client_context.storage_engine == "mmapv1":
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
cls.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
cls.mongos_clients = []
if (
async_client_context.supports_transactions()
and not async_client_context.load_balancer
and not async_client_context.serverless
):
for address in async_client_context.mongoses:
cls.mongos_clients.append(
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
def setUpClass(cls) -> None:
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(
heartbeat_frequency=0.1,
@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
cls.knobs.enable()
@classmethod
async def _tearDown_class(cls):
def tearDownClass(cls) -> None:
cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
async def asyncSetUp(self):
# super call creates internal client cls.client
await super().asyncSetUp()
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not await self.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if async_client_context.storage_engine == "mmapv1":
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
self.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
self.mongos_clients = []
if (
async_client_context.supports_transactions()
and not async_client_context.load_balancer
and not async_client_context.serverless
):
for address in async_client_context.mongoses:
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
# process schemaVersion
# note: we check major schema version during class generation
# note: we do this here because we cannot run assertions in setUpClass
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
self.assertLessEqual(
version,
@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
)
client = await self.async_single_client("{}:{}".format(*session._pinned_address))
self.addAsyncCleanup(client.close)
await self.__set_fail_point(client=client, command_args=spec["failPoint"])
async def _testOperation_createEntities(self, spec):
@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
client, event, count = spec["client"], spec["event"], spec["count"]
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
def _testOperation_waitForEvent(self, spec):
async def _testOperation_waitForEvent(self, spec):
"""Run the waitForEvent test operation.
Wait for a number of events to be published, or fail.
"""
client, event, count = spec["client"], spec["event"], spec["count"]
wait_until(
await async_wait_until(
lambda: self._event_count(client, event) >= count,
f"find {count} {event} event(s)",
)

View File

@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest):
knobs: client_knobs
listener: EventListener
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.mongos_clients = []
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
def setUp(self):
super().setUp()
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
async def asyncTearDown(self) -> None:
self.knobs.disable()
async def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest):
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addAsyncCleanup(client.close)
# Create session0 and session1.
sessions = {}

View File

@ -20,7 +20,7 @@ def event_loop_policy():
return asyncio.get_event_loop_policy()
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="package", autouse=True)
def test_setup_and_teardown():
setup()
yield

View File

@ -107,4 +107,4 @@ Automation
At MongoDB, Inc. we use a continuous integration job that tests each
combination in the matrix. The job starts up Apache, starts a single server
or replica set, and runs ``test_client.py`` with the proper arguments.
See `run-mod-wsgi-tests.sh <https://github.com/mongodb/mongo-python-driver/blob/master/.evergreen/run-mod-wsgi-tests.sh>`_
See `run-mod-wsgi-tests.sh <https://github.com/mongodb/mongo-python-driver/blob/master/.evergreen/scripts/run-mod-wsgi-tests.sh>`_

View File

@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest):
coll: Collection
coll_w0: Collection
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
def setUp(self):
super().setUp()
self.coll = self.db.test
self.coll.drop()
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response."""
@ -785,12 +781,8 @@ class TestBulk(BulkTestBase):
class BulkAuthorizationTestBase(BulkTestBase):
@classmethod
@client_context.require_auth
@client_context.require_no_api_version
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase):
w: Optional[int]
secondary: MongoClient
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.w = client_context.w
cls.secondary = None
if cls.w is not None and cls.w > 1:
def setUp(self):
super().setUp()
self.w = client_context.w
self.secondary = None
if self.w is not None and self.w > 1:
for member in (client_context.hello)["hosts"]:
if member != (client_context.hello)["primary"]:
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
self.secondary = self.single_client(*partition_node(member))
break
@classmethod
def async_tearDownClass(cls):
if cls.secondary:
cls.secondary.close()
def tearDown(self):
if self.secondary:
self.secondary.close()
def cause_wtimeout(self, requests, ordered):
if not client_context.test_commands_enabled:

View File

@ -820,18 +820,16 @@ class ProseSpecTestsMixin:
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
dbs: list
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
cls.dbs = [cls.db, cls.client.pymongo_test_2]
def setUp(self) -> None:
super().setUp()
self.dbs = [self.db, self.client.pymongo_test_2]
@classmethod
def _tearDown_class(cls):
for db in cls.dbs:
cls.client.drop_database(db)
super()._tearDown_class()
def tearDown(self):
for db in self.dbs:
self.client.drop_database(db)
super().tearDown()
def change_stream_with_client(self, client, *args, **kwargs):
return client.watch(*args, **kwargs)
@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
def setUp(self) -> None:
super().setUp()
def change_stream_with_client(self, client, *args, **kwargs):
return client[self.db.name].watch(*args, **kwargs)
@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
@classmethod
@client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
# Use a new collection for each test.
self.watched_collection().drop()
self.watched_collection().insert_one({})
@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener
@classmethod
@client_context.require_connection
def _setup_class(cls):
super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):
cls.client.close()
super()._tearDown_class()
def setUp(self):
super().setUp()
self.listener = AllowListEventListener("aggregate", "getMore")
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.listener.reset()
def setUpCluster(self, scenario_dict):

View File

@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest):
client: MongoClient
@classmethod
def _setup_class(cls):
cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self) -> None:
self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest):
self.assertFalse(client._topology._opened)
# Ensure kill cursors thread has not been started.
kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertFalse(kc_task and not kc_task.done())
# Using the client should open topology and start the thread.
client.admin.command("ping")
self.assertTrue(client._topology._opened)
kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertTrue(kc_task and not kc_task.done())
def test_close_does_not_open_servers(self):
client = self.rs_client(connect=False)
@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest):
def test_server_selection_timeout(self):
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
client.close()
client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest):
self.assertRaises(
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
self.assertAlmostEqual(0, client.options.server_selection_timeout)
client.close()
# Test invalid timeout in URI ignored and set to default.
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)

View File

@ -25,7 +25,7 @@ _IS_SYNC = True
class TestClientContext(UnitTest):
def test_must_connect(self):
if "PYMONGO_MUST_CONNECT" not in os.environ:
if not os.environ.get("PYMONGO_MUST_CONNECT"):
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
self.assertTrue(
@ -37,7 +37,7 @@ class TestClientContext(UnitTest):
)
def test_serverless(self):
if "TEST_SERVERLESS" not in os.environ:
if not os.environ.get("TEST_SERVERLESS"):
raise SkipTest("TEST_SERVERLESS is not set")
self.assertTrue(
@ -47,7 +47,7 @@ class TestClientContext(UnitTest):
)
def test_enableTestCommands_is_disabled(self):
if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ:
if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"):
raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set")
self.assertFalse(
@ -56,7 +56,7 @@ class TestClientContext(UnitTest):
)
def test_setdefaultencoding_worked(self):
if "SETDEFAULTENCODING" not in os.environ:
if not os.environ.get("SETDEFAULTENCODING"):
raise SkipTest("SETDEFAULTENCODING is not set")
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])

View File

@ -97,26 +97,19 @@ class TestCollation(IntegrationTest):
warn_context: Any
collation: Collation
@classmethod
@client_context.require_connection
def _setup_class(cls):
super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
cls.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
def setUp(self) -> None:
super().setUp()
self.listener = OvertCommandListener()
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.db = self.client.pymongo_test
self.collation = Collation("en_US")
self.warn_context = warnings.catch_warnings()
self.warn_context.__enter__()
@classmethod
def _tearDown_class(cls):
cls.warn_context.__exit__()
cls.warn_context = None
cls.client.close()
super()._tearDown_class()
def tearDown(self):
def tearDown(self) -> None:
self.warn_context.__exit__()
self.warn_context = None
self.listener.reset()
super().tearDown()

View File

@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest):
db: Database
client: MongoClient
@classmethod
def _setup_class(cls):
cls.client = MongoClient(connect=False)
cls.db = cls.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self) -> None:
super().setUp()
self.client = self.simple_client(connect=False)
self.db = self.client.pymongo_test
def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5)
@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest):
class TestCollection(IntegrationTest):
w: int
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.w = client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
else:
asyncio.run(cls.async_tearDownClass())
@classmethod
def async_tearDownClass(cls):
cls.db.drop_collection("test_large_limit")
def setUp(self):
self.db.test.drop()
super().setUp()
self.w = client_context.w # type: ignore
def tearDown(self):
self.db.test.drop()
self.db.drop_collection("test_large_limit")
super().tearDown()
@contextlib.contextmanager
def write_concern_collection(self):
@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest):
db.test.insert_one({"y": 1}, bypass_document_validation=True)
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
def predicate():
return db_w0.test.find_one({"x": 1})
wait_until(predicate, "find w:0 replaced document")
def test_update_bypass_document_validation(self):
db = self.db

View File

@ -19,7 +19,12 @@ import sys
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test import (
IntegrationTest,
client_context,
reset_client_context,
unittest,
)
from test.helpers import repl_set_step_down
from test.utils import (
CMAPListener,
@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
listener: CMAPListener
coll: Collection
@classmethod
@client_context.require_replica_set
def _setup_class(cls):
super()._setup_class()
cls.listener = CMAPListener()
cls.client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
def setUp(self):
self.listener = CMAPListener()
self.client = self.rs_or_single_client(
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for connections that
# survive a replica set election.
ensure_all_connected(cls.client)
cls.listener.reset()
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self):
ensure_all_connected(self.client)
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
# Note that all ops use same write-concern as self.db (majority).
self.db.drop_collection("step-down")
self.db.create_collection("step-down")

View File

@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest):
self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
client.close()
def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1()
@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest):
self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
client.close()
if __name__ == "__main__":

View File

@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest):
class TestRawBatchCommandCursor(IntegrationTest):
@classmethod
def _setup_class(cls):
super()._setup_class()
def test_aggregate_raw(self):
c = self.db.test
c.drop()

View File

@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase):
class TestCollectionWCustomType(IntegrationTest):
def setUp(self):
super().setUp()
self.db.test.drop()
def tearDown(self):
@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest):
class TestGridFileCustomType(IntegrationTest):
def setUp(self):
super().setUp()
self.db.drop_collection("fs.files")
self.db.drop_collection("fs.chunks")
@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin:
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
cls.db.test.delete_many({})
def setUp(self):
super().setUp()
self.db.test.delete_many({})
def tearDown(self):
self.input_target.drop()
@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0)
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
cls.db.test.delete_many({})
def setUp(self):
super().setUp()
self.db.test.delete_many({})
def tearDown(self):
self.input_target.drop()
@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0)
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
cls.db.test.delete_many({})
def setUp(self):
super().setUp()
self.db.test.delete_many({})
def tearDown(self):
self.input_target.drop()

View File

@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest):
class TestDatabaseAggregation(IntegrationTest):
def setUp(self):
super().setUp()
self.pipeline: List[Mapping[str, Any]] = [
{"$listLocalSessions": {}},
{"$limit": 1},

Some files were not shown because too many files have changed in this diff Show More