PYTHON-4530 - Move synchronized test code into top-level test directory (#1718)
This commit is contained in:
parent
2d301e2db2
commit
cfa215c185
@ -258,9 +258,6 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
|
|||||||
# Use --capture=tee-sys so pytest prints test output inline:
|
# Use --capture=tee-sys so pytest prints test output inline:
|
||||||
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
|
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
|
||||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
|
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
|
||||||
if [ -z "$TEST_ARGS" ]; then # TODO: remove this in PYTHON-4528
|
|
||||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/synchronous/ $TEST_ARGS
|
|
||||||
fi
|
|
||||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS
|
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS
|
||||||
else
|
else
|
||||||
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
|
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
|
||||||
|
|||||||
1
.github/workflows/test-python.yml
vendored
1
.github/workflows/test-python.yml
vendored
@ -206,5 +206,4 @@ jobs:
|
|||||||
which python
|
which python
|
||||||
pip install -e ".[test]"
|
pip install -e ".[test]"
|
||||||
pytest -v
|
pytest -v
|
||||||
pytest -v test/synchronous/
|
|
||||||
pytest -v test/asynchronous/
|
pytest -v test/asynchronous/
|
||||||
|
|||||||
@ -7,7 +7,7 @@ exclude = (?x)(
|
|||||||
| ^test/conftest.py$
|
| ^test/conftest.py$
|
||||||
)
|
)
|
||||||
|
|
||||||
[mypy-pymongo.synchronous.*,gridfs.synchronous.*,test.synchronous.*]
|
[mypy-pymongo.synchronous.*,gridfs.synchronous.*,test.*]
|
||||||
warn_unused_ignores = false
|
warn_unused_ignores = false
|
||||||
disable_error_code = unused-coroutine
|
disable_error_code = unused-coroutine
|
||||||
|
|
||||||
|
|||||||
427
test/__init__.py
427
test/__init__.py
@ -12,9 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Test suite for pymongo, bson, and gridfs."""
|
"""Synchronous test suite for pymongo, bson, and gridfs."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import gc
|
import gc
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
@ -28,6 +29,31 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
from test.helpers import (
|
||||||
|
COMPRESSORS,
|
||||||
|
IS_SRV,
|
||||||
|
MONGODB_API_VERSION,
|
||||||
|
MULTI_MONGOS_LB_URI,
|
||||||
|
TEST_LOADBALANCER,
|
||||||
|
TEST_SERVERLESS,
|
||||||
|
TLS_OPTIONS,
|
||||||
|
SystemCertsPatcher,
|
||||||
|
_all_users,
|
||||||
|
_create_user,
|
||||||
|
client_knobs,
|
||||||
|
db_pwd,
|
||||||
|
db_user,
|
||||||
|
global_knobs,
|
||||||
|
host,
|
||||||
|
is_server_resolvable,
|
||||||
|
port,
|
||||||
|
print_running_topology,
|
||||||
|
print_thread_stacks,
|
||||||
|
print_thread_tracebacks,
|
||||||
|
sanitize_cmd,
|
||||||
|
sanitize_reply,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ipaddress
|
import ipaddress
|
||||||
@ -38,210 +64,21 @@ except ImportError:
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from test.version import Version
|
from test.version import Version
|
||||||
from typing import Any, Callable, Dict, Generator, no_type_check
|
from typing import Any, Callable, Dict, Generator
|
||||||
from unittest import SkipTest
|
from unittest import SkipTest
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
import pymongo
|
import pymongo
|
||||||
import pymongo.errors
|
import pymongo.errors
|
||||||
from bson.son import SON
|
from bson.son import SON
|
||||||
from pymongo import common, message
|
|
||||||
from pymongo.common import partition_node
|
from pymongo.common import partition_node
|
||||||
from pymongo.hello import HelloCompat
|
from pymongo.hello import HelloCompat
|
||||||
from pymongo.server_api import ServerApi
|
from pymongo.server_api import ServerApi
|
||||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||||
from pymongo.synchronous.database import Database
|
from pymongo.synchronous.database import Database
|
||||||
from pymongo.synchronous.mongo_client import MongoClient
|
from pymongo.synchronous.mongo_client import MongoClient
|
||||||
from pymongo.uri_parser import parse_uri
|
|
||||||
|
|
||||||
if HAVE_SSL:
|
_IS_SYNC = True
|
||||||
import ssl
|
|
||||||
|
|
||||||
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
|
|
||||||
if hasattr(gc, "set_debug"):
|
|
||||||
gc.set_debug(
|
|
||||||
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
|
|
||||||
)
|
|
||||||
|
|
||||||
# The host and port of a single mongod or mongos, or the seed host
|
|
||||||
# for a replica set.
|
|
||||||
host = os.environ.get("DB_IP", "localhost")
|
|
||||||
port = int(os.environ.get("DB_PORT", 27017))
|
|
||||||
IS_SRV = "mongodb+srv" in host
|
|
||||||
|
|
||||||
db_user = os.environ.get("DB_USER", "user")
|
|
||||||
db_pwd = os.environ.get("DB_PASSWORD", "password")
|
|
||||||
|
|
||||||
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
|
|
||||||
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
|
|
||||||
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
|
|
||||||
|
|
||||||
TLS_OPTIONS: Dict = {"tls": True}
|
|
||||||
if CLIENT_PEM:
|
|
||||||
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
|
|
||||||
if CA_PEM:
|
|
||||||
TLS_OPTIONS["tlsCAFile"] = CA_PEM
|
|
||||||
|
|
||||||
COMPRESSORS = os.environ.get("COMPRESSORS")
|
|
||||||
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
|
|
||||||
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
|
|
||||||
TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS"))
|
|
||||||
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
|
|
||||||
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
|
|
||||||
|
|
||||||
if TEST_LOADBALANCER:
|
|
||||||
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
|
||||||
host, port = res["nodelist"][0]
|
|
||||||
db_user = res["username"] or db_user
|
|
||||||
db_pwd = res["password"] or db_pwd
|
|
||||||
elif TEST_SERVERLESS:
|
|
||||||
TEST_LOADBALANCER = True
|
|
||||||
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
|
||||||
host, port = res["nodelist"][0]
|
|
||||||
db_user = res["username"] or db_user
|
|
||||||
db_pwd = res["password"] or db_pwd
|
|
||||||
TLS_OPTIONS = {"tls": True}
|
|
||||||
# Spec says serverless tests must be run with compression.
|
|
||||||
COMPRESSORS = COMPRESSORS or "zlib"
|
|
||||||
|
|
||||||
|
|
||||||
# Shared KMS data.
|
|
||||||
LOCAL_MASTER_KEY = base64.b64decode(
|
|
||||||
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
|
|
||||||
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
|
|
||||||
)
|
|
||||||
AWS_CREDS = {
|
|
||||||
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
|
|
||||||
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
|
|
||||||
}
|
|
||||||
AWS_CREDS_2 = {
|
|
||||||
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
|
|
||||||
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
|
|
||||||
}
|
|
||||||
AZURE_CREDS = {
|
|
||||||
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
|
|
||||||
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
|
|
||||||
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
|
|
||||||
}
|
|
||||||
GCP_CREDS = {
|
|
||||||
"email": os.environ.get("FLE_GCP_EMAIL", ""),
|
|
||||||
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
|
|
||||||
}
|
|
||||||
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
|
|
||||||
|
|
||||||
# Ensure Evergreen metadata doesn't result in truncation
|
|
||||||
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
|
|
||||||
|
|
||||||
|
|
||||||
def is_server_resolvable():
|
|
||||||
"""Returns True if 'server' is resolvable."""
|
|
||||||
socket_timeout = socket.getdefaulttimeout()
|
|
||||||
socket.setdefaulttimeout(1)
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
socket.gethostbyname("server")
|
|
||||||
return True
|
|
||||||
except OSError:
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
socket.setdefaulttimeout(socket_timeout)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
|
||||||
cmd = SON([("createUser", user)])
|
|
||||||
# X509 doesn't use a password
|
|
||||||
if pwd:
|
|
||||||
cmd["pwd"] = pwd
|
|
||||||
cmd["roles"] = roles or ["root"]
|
|
||||||
cmd.update(**kwargs)
|
|
||||||
return authdb.command(cmd)
|
|
||||||
|
|
||||||
|
|
||||||
class client_knobs:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
heartbeat_frequency=None,
|
|
||||||
min_heartbeat_interval=None,
|
|
||||||
kill_cursor_frequency=None,
|
|
||||||
events_queue_frequency=None,
|
|
||||||
):
|
|
||||||
self.heartbeat_frequency = heartbeat_frequency
|
|
||||||
self.min_heartbeat_interval = min_heartbeat_interval
|
|
||||||
self.kill_cursor_frequency = kill_cursor_frequency
|
|
||||||
self.events_queue_frequency = events_queue_frequency
|
|
||||||
|
|
||||||
self.old_heartbeat_frequency = None
|
|
||||||
self.old_min_heartbeat_interval = None
|
|
||||||
self.old_kill_cursor_frequency = None
|
|
||||||
self.old_events_queue_frequency = None
|
|
||||||
self._enabled = False
|
|
||||||
self._stack = None
|
|
||||||
|
|
||||||
def enable(self):
|
|
||||||
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
|
|
||||||
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
|
|
||||||
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
|
|
||||||
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
|
|
||||||
|
|
||||||
if self.heartbeat_frequency is not None:
|
|
||||||
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
|
|
||||||
|
|
||||||
if self.min_heartbeat_interval is not None:
|
|
||||||
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
|
|
||||||
|
|
||||||
if self.kill_cursor_frequency is not None:
|
|
||||||
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
|
|
||||||
|
|
||||||
if self.events_queue_frequency is not None:
|
|
||||||
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
|
|
||||||
self._enabled = True
|
|
||||||
# Store the allocation traceback to catch non-disabled client_knobs.
|
|
||||||
self._stack = "".join(traceback.format_stack())
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.enable()
|
|
||||||
|
|
||||||
@no_type_check
|
|
||||||
def disable(self):
|
|
||||||
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
|
|
||||||
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
|
|
||||||
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
|
|
||||||
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
|
|
||||||
self._enabled = False
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.disable()
|
|
||||||
|
|
||||||
def __call__(self, func):
|
|
||||||
def make_wrapper(f):
|
|
||||||
@wraps(f)
|
|
||||||
def wrap(*args, **kwargs):
|
|
||||||
with self:
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrap
|
|
||||||
|
|
||||||
return make_wrapper(func)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if self._enabled:
|
|
||||||
msg = (
|
|
||||||
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
|
|
||||||
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
|
|
||||||
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
|
|
||||||
common.HEARTBEAT_FREQUENCY,
|
|
||||||
common.MIN_HEARTBEAT_INTERVAL,
|
|
||||||
common.KILL_CURSOR_FREQUENCY,
|
|
||||||
common.EVENTS_QUEUE_FREQUENCY,
|
|
||||||
self._stack,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.disable()
|
|
||||||
raise Exception(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def _all_users(db):
|
|
||||||
return {u["user"] for u in db.command("usersInfo").get("users", [])}
|
|
||||||
|
|
||||||
|
|
||||||
class ClientContext:
|
class ClientContext:
|
||||||
@ -314,7 +151,8 @@ class ClientContext:
|
|||||||
auth_part = ""
|
auth_part = ""
|
||||||
if client_context.auth_enabled:
|
if client_context.auth_enabled:
|
||||||
auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@"
|
auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@"
|
||||||
return f"mongodb://{auth_part}{self.pair}/?{opts_part}"
|
pair = self.pair
|
||||||
|
return f"mongodb://{auth_part}{pair}/?{opts_part}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hello(self):
|
def hello(self):
|
||||||
@ -468,7 +306,7 @@ class ClientContext:
|
|||||||
self.test_commands_enabled = True
|
self.test_commands_enabled = True
|
||||||
self.has_ipv6 = self._server_started_with_ipv6()
|
self.has_ipv6 = self._server_started_with_ipv6()
|
||||||
|
|
||||||
self.is_mongos = self.hello.get("msg") == "isdbgrid"
|
self.is_mongos = (self.hello).get("msg") == "isdbgrid"
|
||||||
if self.is_mongos:
|
if self.is_mongos:
|
||||||
address = self.client.address
|
address = self.client.address
|
||||||
self.mongoses.append(address)
|
self.mongoses.append(address)
|
||||||
@ -604,15 +442,33 @@ class ClientContext:
|
|||||||
|
|
||||||
def _require(self, condition, msg, func=None):
|
def _require(self, condition, msg, func=None):
|
||||||
def make_wrapper(f):
|
def make_wrapper(f):
|
||||||
|
if iscoroutinefunction(f):
|
||||||
|
wraps_async = True
|
||||||
|
else:
|
||||||
|
wraps_async = False
|
||||||
|
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrap(*args, **kwargs):
|
def wrap(*args, **kwargs):
|
||||||
self.init()
|
self.init()
|
||||||
# Always raise SkipTest if we can't connect to MongoDB
|
# Always raise SkipTest if we can't connect to MongoDB
|
||||||
if not self.connected:
|
if not self.connected:
|
||||||
raise SkipTest(f"Cannot connect to MongoDB on {self.pair}")
|
pair = self.pair
|
||||||
if condition():
|
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
|
||||||
return f(*args, **kwargs)
|
if iscoroutinefunction(condition) and condition():
|
||||||
raise SkipTest(msg)
|
if wraps_async:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
elif condition():
|
||||||
|
if wraps_async:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
if "self.pair" in msg:
|
||||||
|
new_msg = msg.replace("self.pair", self.pair)
|
||||||
|
else:
|
||||||
|
new_msg = msg
|
||||||
|
raise SkipTest(new_msg)
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
@ -635,7 +491,7 @@ class ClientContext:
|
|||||||
"""Run a test only if we can connect to MongoDB."""
|
"""Run a test only if we can connect to MongoDB."""
|
||||||
return self._require(
|
return self._require(
|
||||||
lambda: True, # _require checks if we're connected
|
lambda: True, # _require checks if we're connected
|
||||||
f"Cannot connect to MongoDB on {self.pair}",
|
"Cannot connect to MongoDB on self.pair",
|
||||||
func=func,
|
func=func,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -643,7 +499,7 @@ class ClientContext:
|
|||||||
"""Run a test only if we are connected to Atlas Data Lake."""
|
"""Run a test only if we are connected to Atlas Data Lake."""
|
||||||
return self._require(
|
return self._require(
|
||||||
lambda: self.is_data_lake,
|
lambda: self.is_data_lake,
|
||||||
f"Not connected to Atlas Data Lake on {self.pair}",
|
"Not connected to Atlas Data Lake on self.pair",
|
||||||
func=func,
|
func=func,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -816,7 +672,7 @@ class ClientContext:
|
|||||||
if "sharded" in topologies and self.is_mongos:
|
if "sharded" in topologies and self.is_mongos:
|
||||||
return True
|
return True
|
||||||
if "sharded-replicaset" in topologies and self.is_mongos:
|
if "sharded-replicaset" in topologies and self.is_mongos:
|
||||||
shards = list(client_context.client.config.shards.find())
|
shards = (client_context.client.config.shards.find()).to_list()
|
||||||
for shard in shards:
|
for shard in shards:
|
||||||
# For a 3-member RS-backed sharded cluster, shard['host']
|
# For a 3-member RS-backed sharded cluster, shard['host']
|
||||||
# will be 'replicaName/ip1:port1,ip2:port2,ip3:port3'
|
# will be 'replicaName/ip1:port1,ip2:port2,ip3:port3'
|
||||||
@ -969,45 +825,17 @@ class ClientContext:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def max_bson_size(self):
|
def max_bson_size(self):
|
||||||
return self.hello["maxBsonObjectSize"]
|
return (self.hello)["maxBsonObjectSize"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_write_batch_size(self):
|
def max_write_batch_size(self):
|
||||||
return self.hello["maxWriteBatchSize"]
|
return (self.hello)["maxWriteBatchSize"]
|
||||||
|
|
||||||
|
|
||||||
# Reusable client context
|
# Reusable client context
|
||||||
client_context = ClientContext()
|
client_context = ClientContext()
|
||||||
|
|
||||||
|
|
||||||
def sanitize_cmd(cmd):
|
|
||||||
cp = cmd.copy()
|
|
||||||
cp.pop("$clusterTime", None)
|
|
||||||
cp.pop("$db", None)
|
|
||||||
cp.pop("$readPreference", None)
|
|
||||||
cp.pop("lsid", None)
|
|
||||||
if MONGODB_API_VERSION:
|
|
||||||
# Stable API parameters
|
|
||||||
cp.pop("apiVersion", None)
|
|
||||||
# OP_MSG encoding may move the payload type one field to the
|
|
||||||
# end of the command. Do the same here.
|
|
||||||
name = next(iter(cp))
|
|
||||||
try:
|
|
||||||
identifier = message._FIELD_MAP[name]
|
|
||||||
docs = cp.pop(identifier)
|
|
||||||
cp[identifier] = docs
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return cp
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_reply(reply):
|
|
||||||
cp = reply.copy()
|
|
||||||
cp.pop("$clusterTime", None)
|
|
||||||
cp.pop("operationTime", None)
|
|
||||||
return cp
|
|
||||||
|
|
||||||
|
|
||||||
class PyMongoTestCase(unittest.TestCase):
|
class PyMongoTestCase(unittest.TestCase):
|
||||||
def assertEqualCommand(self, expected, actual, msg=None):
|
def assertEqualCommand(self, expected, actual, msg=None):
|
||||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||||
@ -1085,40 +913,23 @@ class PyMongoTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(proc.exitcode, 0)
|
self.assertEqual(proc.exitcode, 0)
|
||||||
|
|
||||||
|
|
||||||
def print_thread_tracebacks() -> None:
|
|
||||||
"""Print all Python thread tracebacks."""
|
|
||||||
for thread_id, frame in sys._current_frames().items():
|
|
||||||
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
|
|
||||||
traceback.print_stack(frame, file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def print_thread_stacks(pid: int) -> None:
|
|
||||||
"""Print all C-level thread stacks for a given process id."""
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
|
|
||||||
else:
|
|
||||||
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
|
|
||||||
|
|
||||||
try:
|
|
||||||
res = subprocess.run(
|
|
||||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
|
|
||||||
else:
|
|
||||||
sys.stderr.write(res.stdout)
|
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTest(PyMongoTestCase):
|
class IntegrationTest(PyMongoTestCase):
|
||||||
"""Base class for TestCases that need a connection to MongoDB to pass."""
|
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
||||||
|
|
||||||
client: MongoClient[dict]
|
client: MongoClient[dict]
|
||||||
db: Database
|
db: Database
|
||||||
credentials: Dict[str, str]
|
credentials: Dict[str, str]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@client_context.require_connection
|
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._setup_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@client_context.require_connection
|
||||||
|
def _setup_class(cls):
|
||||||
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
||||||
raise SkipTest("this test does not support load balancers")
|
raise SkipTest("this test does not support load balancers")
|
||||||
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
||||||
@ -1171,10 +982,6 @@ class MockClientTest(unittest.TestCase):
|
|||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
||||||
|
|
||||||
# Global knobs to speed up the test suite.
|
|
||||||
global_knobs = client_knobs(events_queue_frequency=0.05)
|
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
client_context.init()
|
client_context.init()
|
||||||
warnings.resetwarnings()
|
warnings.resetwarnings()
|
||||||
@ -1182,56 +989,6 @@ def setup():
|
|||||||
global_knobs.enable()
|
global_knobs.enable()
|
||||||
|
|
||||||
|
|
||||||
def _get_executors(topology):
|
|
||||||
executors = []
|
|
||||||
for server in topology._servers.values():
|
|
||||||
# Some MockMonitor do not have an _executor.
|
|
||||||
if hasattr(server._monitor, "_executor"):
|
|
||||||
executors.append(server._monitor._executor)
|
|
||||||
if hasattr(server._monitor, "_rtt_monitor"):
|
|
||||||
executors.append(server._monitor._rtt_monitor._executor)
|
|
||||||
executors.append(topology._Topology__events_executor)
|
|
||||||
if topology._srv_monitor:
|
|
||||||
executors.append(topology._srv_monitor._executor)
|
|
||||||
|
|
||||||
return [e for e in executors if e is not None]
|
|
||||||
|
|
||||||
|
|
||||||
def print_running_topology(topology):
|
|
||||||
running = [e for e in _get_executors(topology) if not e._stopped]
|
|
||||||
if running:
|
|
||||||
print(
|
|
||||||
"WARNING: found Topology with running threads:\n"
|
|
||||||
f" Threads: {running}\n"
|
|
||||||
f" Topology: {topology}\n"
|
|
||||||
f" Creation traceback:\n{topology._settings._stack}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def print_running_clients():
|
|
||||||
from pymongo.synchronous.topology import Topology
|
|
||||||
|
|
||||||
processed = set()
|
|
||||||
# Avoid false positives on the main test client.
|
|
||||||
# XXX: Can be removed after PYTHON-1634 or PYTHON-1896.
|
|
||||||
c = client_context.client
|
|
||||||
if c:
|
|
||||||
processed.add(c._topology._topology_id)
|
|
||||||
# Call collect to manually cleanup any would-be gc'd clients to avoid
|
|
||||||
# false positives.
|
|
||||||
gc.collect()
|
|
||||||
for obj in gc.get_objects():
|
|
||||||
try:
|
|
||||||
if isinstance(obj, Topology):
|
|
||||||
# Avoid printing the same Topology multiple times.
|
|
||||||
if obj._topology_id in processed:
|
|
||||||
continue
|
|
||||||
print_running_topology(obj)
|
|
||||||
processed.add(obj._topology_id)
|
|
||||||
except ReferenceError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def teardown():
|
def teardown():
|
||||||
global_knobs.disable()
|
global_knobs.disable()
|
||||||
garbage = []
|
garbage = []
|
||||||
@ -1266,31 +1023,25 @@ def test_cases(suite):
|
|||||||
yield from test_cases(suite_or_case)
|
yield from test_cases(suite_or_case)
|
||||||
|
|
||||||
|
|
||||||
# Helper method to workaround https://bugs.python.org/issue21724
|
def print_running_clients():
|
||||||
def clear_warning_registry():
|
from pymongo.synchronous.topology import Topology
|
||||||
"""Clear the __warningregistry__ for all modules."""
|
|
||||||
for _, module in list(sys.modules.items()):
|
|
||||||
if hasattr(module, "__warningregistry__"):
|
|
||||||
module.__warningregistry__ = {} # type:ignore[attr-defined]
|
|
||||||
|
|
||||||
|
processed = set()
|
||||||
class SystemCertsPatcher:
|
# Avoid false positives on the main test client.
|
||||||
def __init__(self, ca_certs):
|
# XXX: Can be removed after PYTHON-1634 or PYTHON-1896.
|
||||||
if (
|
c = client_context.client
|
||||||
ssl.OPENSSL_VERSION.lower().startswith("libressl")
|
if c:
|
||||||
and sys.platform == "darwin"
|
processed.add(c._topology._topology_id)
|
||||||
and not _ssl.IS_PYOPENSSL
|
# Call collect to manually cleanup any would-be gc'd clients to avoid
|
||||||
):
|
# false positives.
|
||||||
raise SkipTest(
|
gc.collect()
|
||||||
"LibreSSL on OSX doesn't support setting CA certificates "
|
for obj in gc.get_objects():
|
||||||
"using SSL_CERT_FILE environment variable."
|
try:
|
||||||
)
|
if isinstance(obj, Topology):
|
||||||
self.original_certs = os.environ.get("SSL_CERT_FILE")
|
# Avoid printing the same Topology multiple times.
|
||||||
# Tell OpenSSL where CA certificates live.
|
if obj._topology_id in processed:
|
||||||
os.environ["SSL_CERT_FILE"] = ca_certs
|
continue
|
||||||
|
print_running_topology(obj)
|
||||||
def disable(self):
|
processed.add(obj._topology_id)
|
||||||
if self.original_certs is None:
|
except ReferenceError:
|
||||||
os.environ.pop("SSL_CERT_FILE")
|
pass
|
||||||
else:
|
|
||||||
os.environ["SSL_CERT_FILE"] = self.original_certs
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ import traceback
|
|||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import iscoroutinefunction
|
from asyncio import iscoroutinefunction
|
||||||
from test import (
|
from test.helpers import (
|
||||||
COMPRESSORS,
|
COMPRESSORS,
|
||||||
IS_SRV,
|
IS_SRV,
|
||||||
MONGODB_API_VERSION,
|
MONGODB_API_VERSION,
|
||||||
@ -41,13 +41,14 @@ from test import (
|
|||||||
SystemCertsPatcher,
|
SystemCertsPatcher,
|
||||||
_all_users,
|
_all_users,
|
||||||
_create_user,
|
_create_user,
|
||||||
|
client_knobs,
|
||||||
db_pwd,
|
db_pwd,
|
||||||
db_user,
|
db_user,
|
||||||
global_knobs,
|
global_knobs,
|
||||||
host,
|
host,
|
||||||
is_server_resolvable,
|
is_server_resolvable,
|
||||||
port,
|
port,
|
||||||
print_running_clients,
|
print_running_topology,
|
||||||
print_thread_stacks,
|
print_thread_stacks,
|
||||||
print_thread_tracebacks,
|
print_thread_tracebacks,
|
||||||
sanitize_cmd,
|
sanitize_cmd,
|
||||||
@ -113,6 +114,7 @@ class AsyncClientContext:
|
|||||||
self.is_data_lake = False
|
self.is_data_lake = False
|
||||||
self.load_balancer = TEST_LOADBALANCER
|
self.load_balancer = TEST_LOADBALANCER
|
||||||
self.serverless = TEST_SERVERLESS
|
self.serverless = TEST_SERVERLESS
|
||||||
|
self._fips_enabled = None
|
||||||
if self.load_balancer or self.serverless:
|
if self.load_balancer or self.serverless:
|
||||||
self.default_client_options["loadBalanced"] = True
|
self.default_client_options["loadBalanced"] = True
|
||||||
if COMPRESSORS:
|
if COMPRESSORS:
|
||||||
@ -189,8 +191,7 @@ class AsyncClientContext:
|
|||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
# Return early when connected to dataLake as mongohoused does not
|
# Return early when connected to dataLake as mongohoused does not
|
||||||
# support the getCmdLineOpts command and is tested without TLS.
|
# support the getCmdLineOpts command and is tested without TLS.
|
||||||
build_info: Any = await self.client.admin.command("buildInfo")
|
if os.environ.get("TEST_DATA_LAKE"):
|
||||||
if "dataLake" in build_info:
|
|
||||||
self.is_data_lake = True
|
self.is_data_lake = True
|
||||||
self.auth_enabled = True
|
self.auth_enabled = True
|
||||||
self.client = await self._connect(host, port, username=db_user, password=db_pwd)
|
self.client = await self._connect(host, port, username=db_user, password=db_pwd)
|
||||||
@ -363,6 +364,17 @@ class AsyncClientContext:
|
|||||||
# Raised if self.server_status is None.
|
# Raised if self.server_status is None.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fips_enabled(self):
|
||||||
|
if self._fips_enabled is not None:
|
||||||
|
return self._fips_enabled
|
||||||
|
try:
|
||||||
|
subprocess.check_call(["fips-mode-setup", "--is-enabled"])
|
||||||
|
self._fips_enabled = True
|
||||||
|
except (subprocess.SubprocessError, FileNotFoundError):
|
||||||
|
self._fips_enabled = False
|
||||||
|
return self._fips_enabled
|
||||||
|
|
||||||
def check_auth_type(self, auth_type):
|
def check_auth_type(self, auth_type):
|
||||||
auth_mechs = self.server_parameters.get("authenticationMechanisms", [])
|
auth_mechs = self.server_parameters.get("authenticationMechanisms", [])
|
||||||
return auth_type in auth_mechs
|
return auth_type in auth_mechs
|
||||||
@ -528,6 +540,12 @@ class AsyncClientContext:
|
|||||||
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
|
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def require_no_fips(self, func):
|
||||||
|
"""Run a test only if the host does not have FIPS enabled."""
|
||||||
|
return self._require(
|
||||||
|
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
|
||||||
|
)
|
||||||
|
|
||||||
def require_no_auth(self, func):
|
def require_no_auth(self, func):
|
||||||
"""Run a test only if the server is running without auth enabled."""
|
"""Run a test only if the server is running without auth enabled."""
|
||||||
return self._require(
|
return self._require(
|
||||||
@ -937,6 +955,35 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
|||||||
self.addCleanup(patcher.disable)
|
self.addCleanup(patcher.disable)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMockClientTest(unittest.TestCase):
|
||||||
|
"""Base class for TestCases that use MockClient.
|
||||||
|
|
||||||
|
This class is *not* an IntegrationTest: if properly written, MockClient
|
||||||
|
tests do not require a running server.
|
||||||
|
|
||||||
|
The class temporarily overrides HEARTBEAT_FREQUENCY to speed up tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||||
|
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||||
|
# with loadBalanced=True.
|
||||||
|
@classmethod
|
||||||
|
@async_client_context.require_no_load_balancer
|
||||||
|
def setUpClass(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||||
|
|
||||||
|
self.client_knobs.enable()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.client_knobs.disable()
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
|
|
||||||
async def async_setup():
|
async def async_setup():
|
||||||
await async_client_context.init()
|
await async_client_context.init()
|
||||||
warnings.resetwarnings()
|
warnings.resetwarnings()
|
||||||
@ -976,3 +1023,27 @@ def test_cases(suite):
|
|||||||
else:
|
else:
|
||||||
# unittest.TestSuite
|
# unittest.TestSuite
|
||||||
yield from test_cases(suite_or_case)
|
yield from test_cases(suite_or_case)
|
||||||
|
|
||||||
|
|
||||||
|
def print_running_clients():
|
||||||
|
from pymongo.asynchronous.topology import Topology
|
||||||
|
|
||||||
|
processed = set()
|
||||||
|
# Avoid false positives on the main test client.
|
||||||
|
# XXX: Can be removed after PYTHON-1634 or PYTHON-1896.
|
||||||
|
c = async_client_context.client
|
||||||
|
if c:
|
||||||
|
processed.add(c._topology._topology_id)
|
||||||
|
# Call collect to manually cleanup any would-be gc'd clients to avoid
|
||||||
|
# false positives.
|
||||||
|
gc.collect()
|
||||||
|
for obj in gc.get_objects():
|
||||||
|
try:
|
||||||
|
if isinstance(obj, Topology):
|
||||||
|
# Avoid printing the same Topology multiple times.
|
||||||
|
if obj._topology_id in processed:
|
||||||
|
continue
|
||||||
|
print_running_topology(obj)
|
||||||
|
processed.add(obj._topology_id)
|
||||||
|
except ReferenceError:
|
||||||
|
pass
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from test import setup, teardown
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
_IS_SYNC = True
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def test_setup_and_teardown():
|
def test_setup_and_teardown():
|
||||||
|
|||||||
367
test/helpers.py
Normal file
367
test/helpers.py
Normal file
@ -0,0 +1,367 @@
|
|||||||
|
# Copyright 2024-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Shared constants and helper method for pymongo, bson, and gridfs test suites."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import gc
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ipaddress
|
||||||
|
|
||||||
|
HAVE_IPADDRESS = True
|
||||||
|
except ImportError:
|
||||||
|
HAVE_IPADDRESS = False
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import wraps
|
||||||
|
from test.version import Version
|
||||||
|
from typing import Any, Callable, Dict, Generator, no_type_check
|
||||||
|
from unittest import SkipTest
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
import pymongo
|
||||||
|
import pymongo.errors
|
||||||
|
from bson.son import SON
|
||||||
|
from pymongo import common, message
|
||||||
|
from pymongo.common import partition_node
|
||||||
|
from pymongo.hello import HelloCompat
|
||||||
|
from pymongo.server_api import ServerApi
|
||||||
|
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||||
|
from pymongo.synchronous.database import Database
|
||||||
|
from pymongo.synchronous.mongo_client import MongoClient
|
||||||
|
from pymongo.uri_parser import parse_uri
|
||||||
|
|
||||||
|
if HAVE_SSL:
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
|
||||||
|
if hasattr(gc, "set_debug"):
|
||||||
|
gc.set_debug(
|
||||||
|
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# The host and port of a single mongod or mongos, or the seed host
|
||||||
|
# for a replica set.
|
||||||
|
host = os.environ.get("DB_IP", "localhost")
|
||||||
|
port = int(os.environ.get("DB_PORT", 27017))
|
||||||
|
IS_SRV = "mongodb+srv" in host
|
||||||
|
|
||||||
|
db_user = os.environ.get("DB_USER", "user")
|
||||||
|
db_pwd = os.environ.get("DB_PASSWORD", "password")
|
||||||
|
|
||||||
|
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
|
||||||
|
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
|
||||||
|
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
|
||||||
|
|
||||||
|
TLS_OPTIONS: Dict = {"tls": True}
|
||||||
|
if CLIENT_PEM:
|
||||||
|
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
|
||||||
|
if CA_PEM:
|
||||||
|
TLS_OPTIONS["tlsCAFile"] = CA_PEM
|
||||||
|
|
||||||
|
COMPRESSORS = os.environ.get("COMPRESSORS")
|
||||||
|
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
|
||||||
|
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
|
||||||
|
TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS"))
|
||||||
|
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
|
||||||
|
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
|
||||||
|
|
||||||
|
if TEST_LOADBALANCER:
|
||||||
|
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
||||||
|
host, port = res["nodelist"][0]
|
||||||
|
db_user = res["username"] or db_user
|
||||||
|
db_pwd = res["password"] or db_pwd
|
||||||
|
elif TEST_SERVERLESS:
|
||||||
|
TEST_LOADBALANCER = True
|
||||||
|
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
||||||
|
host, port = res["nodelist"][0]
|
||||||
|
db_user = res["username"] or db_user
|
||||||
|
db_pwd = res["password"] or db_pwd
|
||||||
|
TLS_OPTIONS = {"tls": True}
|
||||||
|
# Spec says serverless tests must be run with compression.
|
||||||
|
COMPRESSORS = COMPRESSORS or "zlib"
|
||||||
|
|
||||||
|
|
||||||
|
# Shared KMS data.
|
||||||
|
LOCAL_MASTER_KEY = base64.b64decode(
|
||||||
|
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
|
||||||
|
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
|
||||||
|
)
|
||||||
|
AWS_CREDS = {
|
||||||
|
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
|
||||||
|
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
|
||||||
|
}
|
||||||
|
AWS_CREDS_2 = {
|
||||||
|
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
|
||||||
|
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
|
||||||
|
}
|
||||||
|
AZURE_CREDS = {
|
||||||
|
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
|
||||||
|
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
|
||||||
|
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
|
||||||
|
}
|
||||||
|
GCP_CREDS = {
|
||||||
|
"email": os.environ.get("FLE_GCP_EMAIL", ""),
|
||||||
|
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
|
||||||
|
}
|
||||||
|
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
|
||||||
|
|
||||||
|
# Ensure Evergreen metadata doesn't result in truncation
|
||||||
|
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
|
||||||
|
|
||||||
|
|
||||||
|
def is_server_resolvable():
|
||||||
|
"""Returns True if 'server' is resolvable."""
|
||||||
|
socket_timeout = socket.getdefaulttimeout()
|
||||||
|
socket.setdefaulttimeout(1)
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
socket.gethostbyname("server")
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
socket.setdefaulttimeout(socket_timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
||||||
|
cmd = SON([("createUser", user)])
|
||||||
|
# X509 doesn't use a password
|
||||||
|
if pwd:
|
||||||
|
cmd["pwd"] = pwd
|
||||||
|
cmd["roles"] = roles or ["root"]
|
||||||
|
cmd.update(**kwargs)
|
||||||
|
return authdb.command(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
class client_knobs:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
heartbeat_frequency=None,
|
||||||
|
min_heartbeat_interval=None,
|
||||||
|
kill_cursor_frequency=None,
|
||||||
|
events_queue_frequency=None,
|
||||||
|
):
|
||||||
|
self.heartbeat_frequency = heartbeat_frequency
|
||||||
|
self.min_heartbeat_interval = min_heartbeat_interval
|
||||||
|
self.kill_cursor_frequency = kill_cursor_frequency
|
||||||
|
self.events_queue_frequency = events_queue_frequency
|
||||||
|
|
||||||
|
self.old_heartbeat_frequency = None
|
||||||
|
self.old_min_heartbeat_interval = None
|
||||||
|
self.old_kill_cursor_frequency = None
|
||||||
|
self.old_events_queue_frequency = None
|
||||||
|
self._enabled = False
|
||||||
|
self._stack = None
|
||||||
|
|
||||||
|
def enable(self):
|
||||||
|
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
|
||||||
|
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
|
||||||
|
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
|
||||||
|
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
|
||||||
|
|
||||||
|
if self.heartbeat_frequency is not None:
|
||||||
|
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
|
||||||
|
|
||||||
|
if self.min_heartbeat_interval is not None:
|
||||||
|
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
|
||||||
|
|
||||||
|
if self.kill_cursor_frequency is not None:
|
||||||
|
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
|
||||||
|
|
||||||
|
if self.events_queue_frequency is not None:
|
||||||
|
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
|
||||||
|
self._enabled = True
|
||||||
|
# Store the allocation traceback to catch non-disabled client_knobs.
|
||||||
|
self._stack = "".join(traceback.format_stack())
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.enable()
|
||||||
|
|
||||||
|
@no_type_check
|
||||||
|
def disable(self):
|
||||||
|
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
|
||||||
|
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
|
||||||
|
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
|
||||||
|
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
|
||||||
|
self._enabled = False
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.disable()
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
def make_wrapper(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrap(*args, **kwargs):
|
||||||
|
with self:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
return make_wrapper(func)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self._enabled:
|
||||||
|
msg = (
|
||||||
|
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
|
||||||
|
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
|
||||||
|
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
|
||||||
|
common.HEARTBEAT_FREQUENCY,
|
||||||
|
common.MIN_HEARTBEAT_INTERVAL,
|
||||||
|
common.KILL_CURSOR_FREQUENCY,
|
||||||
|
common.EVENTS_QUEUE_FREQUENCY,
|
||||||
|
self._stack,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.disable()
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _all_users(db):
|
||||||
|
return {u["user"] for u in db.command("usersInfo").get("users", [])}
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_cmd(cmd):
|
||||||
|
cp = cmd.copy()
|
||||||
|
cp.pop("$clusterTime", None)
|
||||||
|
cp.pop("$db", None)
|
||||||
|
cp.pop("$readPreference", None)
|
||||||
|
cp.pop("lsid", None)
|
||||||
|
if MONGODB_API_VERSION:
|
||||||
|
# Stable API parameters
|
||||||
|
cp.pop("apiVersion", None)
|
||||||
|
# OP_MSG encoding may move the payload type one field to the
|
||||||
|
# end of the command. Do the same here.
|
||||||
|
name = next(iter(cp))
|
||||||
|
try:
|
||||||
|
identifier = message._FIELD_MAP[name]
|
||||||
|
docs = cp.pop(identifier)
|
||||||
|
cp[identifier] = docs
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return cp
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_reply(reply):
|
||||||
|
cp = reply.copy()
|
||||||
|
cp.pop("$clusterTime", None)
|
||||||
|
cp.pop("operationTime", None)
|
||||||
|
return cp
|
||||||
|
|
||||||
|
|
||||||
|
def print_thread_tracebacks() -> None:
|
||||||
|
"""Print all Python thread tracebacks."""
|
||||||
|
for thread_id, frame in sys._current_frames().items():
|
||||||
|
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
|
||||||
|
traceback.print_stack(frame, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def print_thread_stacks(pid: int) -> None:
|
||||||
|
"""Print all C-level thread stacks for a given process id."""
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
|
||||||
|
else:
|
||||||
|
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = subprocess.run(
|
||||||
|
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
|
||||||
|
else:
|
||||||
|
sys.stderr.write(res.stdout)
|
||||||
|
|
||||||
|
|
||||||
|
# Global knobs to speed up the test suite.
|
||||||
|
global_knobs = client_knobs(events_queue_frequency=0.05)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_executors(topology):
|
||||||
|
executors = []
|
||||||
|
for server in topology._servers.values():
|
||||||
|
# Some MockMonitor do not have an _executor.
|
||||||
|
if hasattr(server._monitor, "_executor"):
|
||||||
|
executors.append(server._monitor._executor)
|
||||||
|
if hasattr(server._monitor, "_rtt_monitor"):
|
||||||
|
executors.append(server._monitor._rtt_monitor._executor)
|
||||||
|
executors.append(topology._Topology__events_executor)
|
||||||
|
if topology._srv_monitor:
|
||||||
|
executors.append(topology._srv_monitor._executor)
|
||||||
|
|
||||||
|
return [e for e in executors if e is not None]
|
||||||
|
|
||||||
|
|
||||||
|
def print_running_topology(topology):
|
||||||
|
running = [e for e in _get_executors(topology) if not e._stopped]
|
||||||
|
if running:
|
||||||
|
print(
|
||||||
|
"WARNING: found Topology with running threads:\n"
|
||||||
|
f" Threads: {running}\n"
|
||||||
|
f" Topology: {topology}\n"
|
||||||
|
f" Creation traceback:\n{topology._settings._stack}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cases(suite):
|
||||||
|
"""Iterator over all TestCases within a TestSuite."""
|
||||||
|
for suite_or_case in suite._tests:
|
||||||
|
if isinstance(suite_or_case, unittest.TestCase):
|
||||||
|
# unittest.TestCase
|
||||||
|
yield suite_or_case
|
||||||
|
else:
|
||||||
|
# unittest.TestSuite
|
||||||
|
yield from test_cases(suite_or_case)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper method to workaround https://bugs.python.org/issue21724
|
||||||
|
def clear_warning_registry():
|
||||||
|
"""Clear the __warningregistry__ for all modules."""
|
||||||
|
for _, module in list(sys.modules.items()):
|
||||||
|
if hasattr(module, "__warningregistry__"):
|
||||||
|
module.__warningregistry__ = {} # type:ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
class SystemCertsPatcher:
|
||||||
|
def __init__(self, ca_certs):
|
||||||
|
if (
|
||||||
|
ssl.OPENSSL_VERSION.lower().startswith("libressl")
|
||||||
|
and sys.platform == "darwin"
|
||||||
|
and not _ssl.IS_PYOPENSSL
|
||||||
|
):
|
||||||
|
raise SkipTest(
|
||||||
|
"LibreSSL on OSX doesn't support setting CA certificates "
|
||||||
|
"using SSL_CERT_FILE environment variable."
|
||||||
|
)
|
||||||
|
self.original_certs = os.environ.get("SSL_CERT_FILE")
|
||||||
|
# Tell OpenSSL where CA certificates live.
|
||||||
|
os.environ["SSL_CERT_FILE"] = ca_certs
|
||||||
|
|
||||||
|
def disable(self):
|
||||||
|
if self.original_certs is None:
|
||||||
|
os.environ.pop("SSL_CERT_FILE")
|
||||||
|
else:
|
||||||
|
os.environ["SSL_CERT_FILE"] = self.original_certs
|
||||||
@ -1,976 +0,0 @@
|
|||||||
# Copyright 2010-present MongoDB, Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Asynchronous test suite for pymongo, bson, and gridfs."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import gc
|
|
||||||
import multiprocessing
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
import socket
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import unittest
|
|
||||||
import warnings
|
|
||||||
from asyncio import iscoroutinefunction
|
|
||||||
from test import (
|
|
||||||
COMPRESSORS,
|
|
||||||
IS_SRV,
|
|
||||||
MONGODB_API_VERSION,
|
|
||||||
MULTI_MONGOS_LB_URI,
|
|
||||||
TEST_LOADBALANCER,
|
|
||||||
TEST_SERVERLESS,
|
|
||||||
TLS_OPTIONS,
|
|
||||||
SystemCertsPatcher,
|
|
||||||
_all_users,
|
|
||||||
_create_user,
|
|
||||||
db_pwd,
|
|
||||||
db_user,
|
|
||||||
global_knobs,
|
|
||||||
host,
|
|
||||||
is_server_resolvable,
|
|
||||||
port,
|
|
||||||
print_running_clients,
|
|
||||||
print_thread_stacks,
|
|
||||||
print_thread_tracebacks,
|
|
||||||
sanitize_cmd,
|
|
||||||
sanitize_reply,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ipaddress
|
|
||||||
|
|
||||||
HAVE_IPADDRESS = True
|
|
||||||
except ImportError:
|
|
||||||
HAVE_IPADDRESS = False
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from functools import wraps
|
|
||||||
from test.version import Version
|
|
||||||
from typing import Any, Callable, Dict, Generator
|
|
||||||
from unittest import SkipTest
|
|
||||||
from urllib.parse import quote_plus
|
|
||||||
|
|
||||||
import pymongo
|
|
||||||
import pymongo.errors
|
|
||||||
from bson.son import SON
|
|
||||||
from pymongo.common import partition_node
|
|
||||||
from pymongo.hello import HelloCompat
|
|
||||||
from pymongo.server_api import ServerApi
|
|
||||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
|
||||||
from pymongo.synchronous.database import Database
|
|
||||||
from pymongo.synchronous.mongo_client import MongoClient
|
|
||||||
|
|
||||||
_IS_SYNC = True
|
|
||||||
|
|
||||||
|
|
||||||
class ClientContext:
|
|
||||||
client: MongoClient
|
|
||||||
|
|
||||||
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Create a client and grab essential information from the server."""
|
|
||||||
self.connection_attempts = []
|
|
||||||
self.connected = False
|
|
||||||
self.w = None
|
|
||||||
self.nodes = set()
|
|
||||||
self.replica_set_name = None
|
|
||||||
self.cmd_line = None
|
|
||||||
self.server_status = None
|
|
||||||
self.version = Version(-1) # Needs to be comparable with Version
|
|
||||||
self.auth_enabled = False
|
|
||||||
self.test_commands_enabled = False
|
|
||||||
self.server_parameters = {}
|
|
||||||
self._hello = None
|
|
||||||
self.is_mongos = False
|
|
||||||
self.mongoses = []
|
|
||||||
self.is_rs = False
|
|
||||||
self.has_ipv6 = False
|
|
||||||
self.tls = False
|
|
||||||
self.tlsCertificateKeyFile = False
|
|
||||||
self.server_is_resolvable = is_server_resolvable()
|
|
||||||
self.default_client_options: Dict = {}
|
|
||||||
self.sessions_enabled = False
|
|
||||||
self.client = None # type: ignore
|
|
||||||
self.conn_lock = threading.Lock()
|
|
||||||
self.is_data_lake = False
|
|
||||||
self.load_balancer = TEST_LOADBALANCER
|
|
||||||
self.serverless = TEST_SERVERLESS
|
|
||||||
if self.load_balancer or self.serverless:
|
|
||||||
self.default_client_options["loadBalanced"] = True
|
|
||||||
if COMPRESSORS:
|
|
||||||
self.default_client_options["compressors"] = COMPRESSORS
|
|
||||||
if MONGODB_API_VERSION:
|
|
||||||
server_api = ServerApi(MONGODB_API_VERSION)
|
|
||||||
self.default_client_options["server_api"] = server_api
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client_options(self):
|
|
||||||
"""Return the MongoClient options for creating a duplicate client."""
|
|
||||||
opts = client_context.default_client_options.copy()
|
|
||||||
opts["host"] = host
|
|
||||||
opts["port"] = port
|
|
||||||
if client_context.auth_enabled:
|
|
||||||
opts["username"] = db_user
|
|
||||||
opts["password"] = db_pwd
|
|
||||||
if self.replica_set_name:
|
|
||||||
opts["replicaSet"] = self.replica_set_name
|
|
||||||
return opts
|
|
||||||
|
|
||||||
@property
|
|
||||||
def uri(self):
|
|
||||||
"""Return the MongoClient URI for creating a duplicate client."""
|
|
||||||
opts = client_context.default_client_options.copy()
|
|
||||||
opts.pop("server_api", None) # Cannot be set from the URI
|
|
||||||
opts_parts = []
|
|
||||||
for opt, val in opts.items():
|
|
||||||
strval = str(val)
|
|
||||||
if isinstance(val, bool):
|
|
||||||
strval = strval.lower()
|
|
||||||
opts_parts.append(f"{opt}={quote_plus(strval)}")
|
|
||||||
opts_part = "&".join(opts_parts)
|
|
||||||
auth_part = ""
|
|
||||||
if client_context.auth_enabled:
|
|
||||||
auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@"
|
|
||||||
pair = self.pair
|
|
||||||
return f"mongodb://{auth_part}{pair}/?{opts_part}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def hello(self):
|
|
||||||
if not self._hello:
|
|
||||||
if self.serverless or self.load_balancer:
|
|
||||||
self._hello = self.client.admin.command(HelloCompat.CMD)
|
|
||||||
else:
|
|
||||||
self._hello = self.client.admin.command(HelloCompat.LEGACY_CMD)
|
|
||||||
return self._hello
|
|
||||||
|
|
||||||
def _connect(self, host, port, **kwargs):
|
|
||||||
kwargs.update(self.default_client_options)
|
|
||||||
client: MongoClient = pymongo.MongoClient(
|
|
||||||
host, port, serverSelectionTimeoutMS=5000, **kwargs
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
client.admin.command("ping") # Can we connect?
|
|
||||||
except pymongo.errors.OperationFailure as exc:
|
|
||||||
# SERVER-32063
|
|
||||||
self.connection_attempts.append(
|
|
||||||
f"connected client {client!r}, but legacy hello failed: {exc}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.connection_attempts.append(f"successfully connected client {client!r}")
|
|
||||||
# If connected, then return client with default timeout
|
|
||||||
return pymongo.MongoClient(host, port, **kwargs)
|
|
||||||
except pymongo.errors.ConnectionFailure as exc:
|
|
||||||
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
def _init_client(self):
|
|
||||||
self.client = self._connect(host, port)
|
|
||||||
if self.client is not None:
|
|
||||||
# Return early when connected to dataLake as mongohoused does not
|
|
||||||
# support the getCmdLineOpts command and is tested without TLS.
|
|
||||||
build_info: Any = self.client.admin.command("buildInfo")
|
|
||||||
if "dataLake" in build_info:
|
|
||||||
self.is_data_lake = True
|
|
||||||
self.auth_enabled = True
|
|
||||||
self.client = self._connect(host, port, username=db_user, password=db_pwd)
|
|
||||||
self.connected = True
|
|
||||||
return
|
|
||||||
|
|
||||||
if HAVE_SSL and not self.client:
|
|
||||||
# Is MongoDB configured for SSL?
|
|
||||||
self.client = self._connect(host, port, **TLS_OPTIONS)
|
|
||||||
if self.client:
|
|
||||||
self.tls = True
|
|
||||||
self.default_client_options.update(TLS_OPTIONS)
|
|
||||||
self.tlsCertificateKeyFile = True
|
|
||||||
|
|
||||||
if self.client:
|
|
||||||
self.connected = True
|
|
||||||
|
|
||||||
if self.serverless:
|
|
||||||
self.auth_enabled = True
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.cmd_line = self.client.admin.command("getCmdLineOpts")
|
|
||||||
except pymongo.errors.OperationFailure as e:
|
|
||||||
assert e.details is not None
|
|
||||||
msg = e.details.get("errmsg", "")
|
|
||||||
if e.code == 13 or "unauthorized" in msg or "login" in msg:
|
|
||||||
# Unauthorized.
|
|
||||||
self.auth_enabled = True
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
self.auth_enabled = self._server_started_with_auth()
|
|
||||||
|
|
||||||
if self.auth_enabled:
|
|
||||||
if not self.serverless and not IS_SRV:
|
|
||||||
# See if db_user already exists.
|
|
||||||
if not self._check_user_provided():
|
|
||||||
_create_user(self.client.admin, db_user, db_pwd)
|
|
||||||
|
|
||||||
self.client = self._connect(
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
username=db_user,
|
|
||||||
password=db_pwd,
|
|
||||||
replicaSet=self.replica_set_name,
|
|
||||||
**self.default_client_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
# May not have this if OperationFailure was raised earlier.
|
|
||||||
self.cmd_line = self.client.admin.command("getCmdLineOpts")
|
|
||||||
|
|
||||||
if self.serverless:
|
|
||||||
self.server_status = {}
|
|
||||||
else:
|
|
||||||
self.server_status = self.client.admin.command("serverStatus")
|
|
||||||
if self.storage_engine == "mmapv1":
|
|
||||||
# MMAPv1 does not support retryWrites=True.
|
|
||||||
self.default_client_options["retryWrites"] = False
|
|
||||||
|
|
||||||
hello = self.hello
|
|
||||||
self.sessions_enabled = "logicalSessionTimeoutMinutes" in hello
|
|
||||||
|
|
||||||
if "setName" in hello:
|
|
||||||
self.replica_set_name = str(hello["setName"])
|
|
||||||
self.is_rs = True
|
|
||||||
if self.auth_enabled:
|
|
||||||
# It doesn't matter which member we use as the seed here.
|
|
||||||
self.client = pymongo.MongoClient(
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
username=db_user,
|
|
||||||
password=db_pwd,
|
|
||||||
replicaSet=self.replica_set_name,
|
|
||||||
**self.default_client_options,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.client = pymongo.MongoClient(
|
|
||||||
host, port, replicaSet=self.replica_set_name, **self.default_client_options
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the authoritative hello result from the primary.
|
|
||||||
self._hello = None
|
|
||||||
hello = self.hello
|
|
||||||
nodes = [partition_node(node.lower()) for node in hello.get("hosts", [])]
|
|
||||||
nodes.extend([partition_node(node.lower()) for node in hello.get("passives", [])])
|
|
||||||
nodes.extend([partition_node(node.lower()) for node in hello.get("arbiters", [])])
|
|
||||||
self.nodes = set(nodes)
|
|
||||||
else:
|
|
||||||
self.nodes = {(host, port)}
|
|
||||||
self.w = len(hello.get("hosts", [])) or 1
|
|
||||||
self.version = Version.from_client(self.client)
|
|
||||||
|
|
||||||
if self.serverless:
|
|
||||||
self.server_parameters = {
|
|
||||||
"requireApiVersion": False,
|
|
||||||
"enableTestCommands": True,
|
|
||||||
}
|
|
||||||
self.test_commands_enabled = True
|
|
||||||
self.has_ipv6 = False
|
|
||||||
else:
|
|
||||||
self.server_parameters = self.client.admin.command("getParameter", "*")
|
|
||||||
assert self.cmd_line is not None
|
|
||||||
if self.server_parameters["enableTestCommands"]:
|
|
||||||
self.test_commands_enabled = True
|
|
||||||
elif "parsed" in self.cmd_line:
|
|
||||||
params = self.cmd_line["parsed"].get("setParameter", [])
|
|
||||||
if "enableTestCommands=1" in params:
|
|
||||||
self.test_commands_enabled = True
|
|
||||||
else:
|
|
||||||
params = self.cmd_line["parsed"].get("setParameter", {})
|
|
||||||
if params.get("enableTestCommands") == "1":
|
|
||||||
self.test_commands_enabled = True
|
|
||||||
self.has_ipv6 = self._server_started_with_ipv6()
|
|
||||||
|
|
||||||
self.is_mongos = (self.hello).get("msg") == "isdbgrid"
|
|
||||||
if self.is_mongos:
|
|
||||||
address = self.client.address
|
|
||||||
self.mongoses.append(address)
|
|
||||||
if not self.serverless:
|
|
||||||
# Check for another mongos on the next port.
|
|
||||||
assert address is not None
|
|
||||||
next_address = address[0], address[1] + 1
|
|
||||||
mongos_client = self._connect(*next_address, **self.default_client_options)
|
|
||||||
if mongos_client:
|
|
||||||
hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD)
|
|
||||||
if hello.get("msg") == "isdbgrid":
|
|
||||||
self.mongoses.append(next_address)
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
with self.conn_lock:
|
|
||||||
if not self.client and not self.connection_attempts:
|
|
||||||
self._init_client()
|
|
||||||
|
|
||||||
def connection_attempt_info(self):
|
|
||||||
return "\n".join(self.connection_attempts)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def host(self):
|
|
||||||
if self.is_rs and not IS_SRV:
|
|
||||||
primary = self.client.primary
|
|
||||||
return str(primary[0]) if primary is not None else host
|
|
||||||
return host
|
|
||||||
|
|
||||||
@property
|
|
||||||
def port(self):
|
|
||||||
if self.is_rs and not IS_SRV:
|
|
||||||
primary = self.client.primary
|
|
||||||
return primary[1] if primary is not None else port
|
|
||||||
return port
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pair(self):
|
|
||||||
return "%s:%d" % (self.host, self.port)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def has_secondaries(self):
|
|
||||||
if not self.client:
|
|
||||||
return False
|
|
||||||
return bool(len(self.client.secondaries))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def storage_engine(self):
|
|
||||||
try:
|
|
||||||
return self.server_status.get("storageEngine", {}).get( # type:ignore[union-attr]
|
|
||||||
"name"
|
|
||||||
)
|
|
||||||
except AttributeError:
|
|
||||||
# Raised if self.server_status is None.
|
|
||||||
return None
|
|
||||||
|
|
||||||
def check_auth_type(self, auth_type):
|
|
||||||
auth_mechs = self.server_parameters.get("authenticationMechanisms", [])
|
|
||||||
return auth_type in auth_mechs
|
|
||||||
|
|
||||||
def _check_user_provided(self):
|
|
||||||
"""Return True if db_user/db_password is already an admin user."""
|
|
||||||
client: MongoClient = pymongo.MongoClient(
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
username=db_user,
|
|
||||||
password=db_pwd,
|
|
||||||
**self.default_client_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return db_user in _all_users(client.admin)
|
|
||||||
except pymongo.errors.OperationFailure as e:
|
|
||||||
assert e.details is not None
|
|
||||||
msg = e.details.get("errmsg", "")
|
|
||||||
if e.code == 18 or "auth fails" in msg:
|
|
||||||
# Auth failed.
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
def _server_started_with_auth(self):
|
|
||||||
# MongoDB >= 2.0
|
|
||||||
assert self.cmd_line is not None
|
|
||||||
if "parsed" in self.cmd_line:
|
|
||||||
parsed = self.cmd_line["parsed"]
|
|
||||||
# MongoDB >= 2.6
|
|
||||||
if "security" in parsed:
|
|
||||||
security = parsed["security"]
|
|
||||||
# >= rc3
|
|
||||||
if "authorization" in security:
|
|
||||||
return security["authorization"] == "enabled"
|
|
||||||
# < rc3
|
|
||||||
return security.get("auth", False) or bool(security.get("keyFile"))
|
|
||||||
return parsed.get("auth", False) or bool(parsed.get("keyFile"))
|
|
||||||
# Legacy
|
|
||||||
argv = self.cmd_line["argv"]
|
|
||||||
return "--auth" in argv or "--keyFile" in argv
|
|
||||||
|
|
||||||
def _server_started_with_ipv6(self):
|
|
||||||
if not socket.has_ipv6:
|
|
||||||
return False
|
|
||||||
|
|
||||||
assert self.cmd_line is not None
|
|
||||||
if "parsed" in self.cmd_line:
|
|
||||||
if not self.cmd_line["parsed"].get("net", {}).get("ipv6"):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
if "--ipv6" not in self.cmd_line["argv"]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# The server was started with --ipv6. Is there an IPv6 route to it?
|
|
||||||
try:
|
|
||||||
for info in socket.getaddrinfo(self.host, self.port):
|
|
||||||
if info[0] == socket.AF_INET6:
|
|
||||||
return True
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _require(self, condition, msg, func=None):
|
|
||||||
def make_wrapper(f):
|
|
||||||
if iscoroutinefunction(f):
|
|
||||||
wraps_async = True
|
|
||||||
else:
|
|
||||||
wraps_async = False
|
|
||||||
|
|
||||||
@wraps(f)
|
|
||||||
def wrap(*args, **kwargs):
|
|
||||||
self.init()
|
|
||||||
# Always raise SkipTest if we can't connect to MongoDB
|
|
||||||
if not self.connected:
|
|
||||||
pair = self.pair
|
|
||||||
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
|
|
||||||
if iscoroutinefunction(condition) and condition():
|
|
||||||
if wraps_async:
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
elif condition():
|
|
||||||
if wraps_async:
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
if "self.pair" in msg:
|
|
||||||
new_msg = msg.replace("self.pair", self.pair)
|
|
||||||
else:
|
|
||||||
new_msg = msg
|
|
||||||
raise SkipTest(new_msg)
|
|
||||||
|
|
||||||
return wrap
|
|
||||||
|
|
||||||
if func is None:
|
|
||||||
|
|
||||||
def decorate(f):
|
|
||||||
return make_wrapper(f)
|
|
||||||
|
|
||||||
return decorate
|
|
||||||
return make_wrapper(func)
|
|
||||||
|
|
||||||
def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
|
|
||||||
kwargs["writeConcern"] = {"w": self.w}
|
|
||||||
return _create_user(self.client[dbname], user, pwd, roles, **kwargs)
|
|
||||||
|
|
||||||
def drop_user(self, dbname, user):
|
|
||||||
self.client[dbname].command("dropUser", user, writeConcern={"w": self.w})
|
|
||||||
|
|
||||||
def require_connection(self, func):
|
|
||||||
"""Run a test only if we can connect to MongoDB."""
|
|
||||||
return self._require(
|
|
||||||
lambda: True, # _require checks if we're connected
|
|
||||||
"Cannot connect to MongoDB on self.pair",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_data_lake(self, func):
|
|
||||||
"""Run a test only if we are connected to Atlas Data Lake."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.is_data_lake,
|
|
||||||
"Not connected to Atlas Data Lake on self.pair",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_mmap(self, func):
|
|
||||||
"""Run a test only if the server is not using the MMAPv1 storage
|
|
||||||
engine. Only works for standalone and replica sets; tests are
|
|
||||||
run regardless of storage engine on sharded clusters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def is_not_mmap():
|
|
||||||
if self.is_mongos:
|
|
||||||
return True
|
|
||||||
return self.storage_engine != "mmapv1"
|
|
||||||
|
|
||||||
return self._require(is_not_mmap, "Storage engine must not be MMAPv1", func=func)
|
|
||||||
|
|
||||||
def require_version_min(self, *ver):
|
|
||||||
"""Run a test only if the server version is at least ``version``."""
|
|
||||||
other_version = Version(*ver)
|
|
||||||
return self._require(
|
|
||||||
lambda: self.version >= other_version,
|
|
||||||
"Server version must be at least %s" % str(other_version),
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_version_max(self, *ver):
|
|
||||||
"""Run a test only if the server version is at most ``version``."""
|
|
||||||
other_version = Version(*ver)
|
|
||||||
return self._require(
|
|
||||||
lambda: self.version <= other_version,
|
|
||||||
"Server version must be at most %s" % str(other_version),
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_auth(self, func):
|
|
||||||
"""Run a test only if the server is running with auth enabled."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_auth(self, func):
|
|
||||||
"""Run a test only if the server is running without auth enabled."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not self.auth_enabled,
|
|
||||||
"Authentication must not be enabled on the server",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_replica_set(self, func):
|
|
||||||
"""Run a test only if the client is connected to a replica set."""
|
|
||||||
return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func)
|
|
||||||
|
|
||||||
def require_secondaries_count(self, count):
|
|
||||||
"""Run a test only if the client is connected to a replica set that has
|
|
||||||
`count` secondaries.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def sec_count():
|
|
||||||
return 0 if not self.client else len(self.client.secondaries)
|
|
||||||
|
|
||||||
return self._require(lambda: sec_count() >= count, "Not enough secondaries available")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_secondary_read_pref(self):
|
|
||||||
if self.has_secondaries:
|
|
||||||
return True
|
|
||||||
if self.is_mongos:
|
|
||||||
shard = self.client.config.shards.find_one()["host"] # type:ignore[index]
|
|
||||||
num_members = shard.count(",") + 1
|
|
||||||
return num_members > 1
|
|
||||||
return False
|
|
||||||
|
|
||||||
def require_secondary_read_pref(self):
|
|
||||||
"""Run a test only if the client is connected to a cluster that
|
|
||||||
supports secondary read preference
|
|
||||||
"""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.supports_secondary_read_pref,
|
|
||||||
"This cluster does not support secondary read preference",
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_replica_set(self, func):
|
|
||||||
"""Run a test if the client is *not* connected to a replica set."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not self.is_rs, "Connected to a replica set, not a standalone mongod", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_ipv6(self, func):
|
|
||||||
"""Run a test only if the client can connect to a server via IPv6."""
|
|
||||||
return self._require(lambda: self.has_ipv6, "No IPv6", func=func)
|
|
||||||
|
|
||||||
def require_no_mongos(self, func):
|
|
||||||
"""Run a test only if the client is not connected to a mongos."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not self.is_mongos, "Must be connected to a mongod, not a mongos", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_mongos(self, func):
|
|
||||||
"""Run a test only if the client is connected to a mongos."""
|
|
||||||
return self._require(lambda: self.is_mongos, "Must be connected to a mongos", func=func)
|
|
||||||
|
|
||||||
def require_multiple_mongoses(self, func):
|
|
||||||
"""Run a test only if the client is connected to a sharded cluster
|
|
||||||
that has 2 mongos nodes.
|
|
||||||
"""
|
|
||||||
return self._require(
|
|
||||||
lambda: len(self.mongoses) > 1, "Must have multiple mongoses available", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_standalone(self, func):
|
|
||||||
"""Run a test only if the client is connected to a standalone."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not (self.is_mongos or self.is_rs),
|
|
||||||
"Must be connected to a standalone",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_standalone(self, func):
|
|
||||||
"""Run a test only if the client is not connected to a standalone."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.is_mongos or self.is_rs,
|
|
||||||
"Must be connected to a replica set or mongos",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_load_balancer(self, func):
|
|
||||||
"""Run a test only if the client is connected to a load balancer."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.load_balancer, "Must be connected to a load balancer", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_load_balancer(self, func):
|
|
||||||
"""Run a test only if the client is not connected to a load balancer."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not self.load_balancer, "Must not be connected to a load balancer", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_serverless(self, func):
|
|
||||||
"""Run a test only if the client is not connected to serverless."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not self.serverless, "Must not be connected to serverless", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_change_streams(self, func):
|
|
||||||
"""Run a test only if the server supports change streams."""
|
|
||||||
return self.require_no_mmap(self.require_no_standalone(self.require_no_serverless(func)))
|
|
||||||
|
|
||||||
def is_topology_type(self, topologies):
|
|
||||||
unknown = set(topologies) - {
|
|
||||||
"single",
|
|
||||||
"replicaset",
|
|
||||||
"sharded",
|
|
||||||
"sharded-replicaset",
|
|
||||||
"load-balanced",
|
|
||||||
}
|
|
||||||
if unknown:
|
|
||||||
raise AssertionError(f"Unknown topologies: {unknown!r}")
|
|
||||||
if self.load_balancer:
|
|
||||||
if "load-balanced" in topologies:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
if "single" in topologies and not (self.is_mongos or self.is_rs):
|
|
||||||
return True
|
|
||||||
if "replicaset" in topologies and self.is_rs:
|
|
||||||
return True
|
|
||||||
if "sharded" in topologies and self.is_mongos:
|
|
||||||
return True
|
|
||||||
if "sharded-replicaset" in topologies and self.is_mongos:
|
|
||||||
shards = (client_context.client.config.shards.find()).to_list()
|
|
||||||
for shard in shards:
|
|
||||||
# For a 3-member RS-backed sharded cluster, shard['host']
|
|
||||||
# will be 'replicaName/ip1:port1,ip2:port2,ip3:port3'
|
|
||||||
# Otherwise it will be 'ip1:port1'
|
|
||||||
host_spec = shard["host"]
|
|
||||||
if not len(host_spec.split("/")) > 1:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def require_cluster_type(self, topologies=None):
|
|
||||||
"""Run a test only if the client is connected to a cluster that
|
|
||||||
conforms to one of the specified topologies. Acceptable topologies
|
|
||||||
are 'single', 'replicaset', and 'sharded'.
|
|
||||||
"""
|
|
||||||
topologies = topologies or []
|
|
||||||
|
|
||||||
def _is_valid_topology():
|
|
||||||
return self.is_topology_type(topologies)
|
|
||||||
|
|
||||||
return self._require(_is_valid_topology, "Cluster type not in %s" % (topologies))
|
|
||||||
|
|
||||||
def require_test_commands(self, func):
|
|
||||||
"""Run a test only if the server has test commands enabled."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.test_commands_enabled, "Test commands must be enabled", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_failCommand_fail_point(self, func):
|
|
||||||
"""Run a test only if the server supports the failCommand fail
|
|
||||||
point.
|
|
||||||
"""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.supports_failCommand_fail_point,
|
|
||||||
"failCommand fail point must be supported",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_failCommand_appName(self, func):
|
|
||||||
"""Run a test only if the server supports the failCommand appName."""
|
|
||||||
# SERVER-47195
|
|
||||||
return self._require(
|
|
||||||
lambda: (self.test_commands_enabled and self.version >= (4, 4, -1)),
|
|
||||||
"failCommand appName must be supported",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_failCommand_blockConnection(self, func):
|
|
||||||
"""Run a test only if the server supports failCommand blockConnection."""
|
|
||||||
return self._require(
|
|
||||||
lambda: (
|
|
||||||
self.test_commands_enabled
|
|
||||||
and (
|
|
||||||
(not self.is_mongos and self.version >= (4, 2, 9))
|
|
||||||
or (self.is_mongos and self.version >= (4, 4))
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"failCommand blockConnection is not supported",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_tls(self, func):
|
|
||||||
"""Run a test only if the client can connect over TLS."""
|
|
||||||
return self._require(lambda: self.tls, "Must be able to connect via TLS", func=func)
|
|
||||||
|
|
||||||
def require_no_tls(self, func):
|
|
||||||
"""Run a test only if the client can connect over TLS."""
|
|
||||||
return self._require(lambda: not self.tls, "Must be able to connect without TLS", func=func)
|
|
||||||
|
|
||||||
def require_tlsCertificateKeyFile(self, func):
|
|
||||||
"""Run a test only if the client can connect with tlsCertificateKeyFile."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.tlsCertificateKeyFile,
|
|
||||||
"Must be able to connect with tlsCertificateKeyFile",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_server_resolvable(self, func):
|
|
||||||
"""Run a test only if the hostname 'server' is resolvable."""
|
|
||||||
return self._require(
|
|
||||||
lambda: self.server_is_resolvable,
|
|
||||||
"No hosts entry for 'server'. Cannot validate hostname in the certificate",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_sessions(self, func):
|
|
||||||
"""Run a test only if the deployment supports sessions."""
|
|
||||||
return self._require(lambda: self.sessions_enabled, "Sessions not supported", func=func)
|
|
||||||
|
|
||||||
def supports_retryable_writes(self):
|
|
||||||
if self.storage_engine == "mmapv1":
|
|
||||||
return False
|
|
||||||
if not self.sessions_enabled:
|
|
||||||
return False
|
|
||||||
return self.is_mongos or self.is_rs
|
|
||||||
|
|
||||||
def require_retryable_writes(self, func):
|
|
||||||
"""Run a test only if the deployment supports retryable writes."""
|
|
||||||
return self._require(
|
|
||||||
self.supports_retryable_writes,
|
|
||||||
"This server does not support retryable writes",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def supports_transactions(self):
|
|
||||||
if self.storage_engine == "mmapv1":
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.version.at_least(4, 1, 8):
|
|
||||||
return self.is_mongos or self.is_rs
|
|
||||||
|
|
||||||
if self.version.at_least(4, 0):
|
|
||||||
return self.is_rs
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def require_transactions(self, func):
|
|
||||||
"""Run a test only if the deployment might support transactions.
|
|
||||||
|
|
||||||
*Might* because this does not test the storage engine or FCV.
|
|
||||||
"""
|
|
||||||
return self._require(
|
|
||||||
self.supports_transactions, "Transactions are not supported", func=func
|
|
||||||
)
|
|
||||||
|
|
||||||
def require_no_api_version(self, func):
|
|
||||||
"""Skip this test when testing with requireApiVersion."""
|
|
||||||
return self._require(
|
|
||||||
lambda: not MONGODB_API_VERSION,
|
|
||||||
"This test does not work with requireApiVersion",
|
|
||||||
func=func,
|
|
||||||
)
|
|
||||||
|
|
||||||
def mongos_seeds(self):
|
|
||||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_failCommand_fail_point(self):
|
|
||||||
"""Does the server support the failCommand fail point?"""
|
|
||||||
if self.is_mongos:
|
|
||||||
return self.version.at_least(4, 1, 5) and self.test_commands_enabled
|
|
||||||
else:
|
|
||||||
return self.version.at_least(4, 0) and self.test_commands_enabled
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_hint_with_min_max_queries(self):
|
|
||||||
"""Does the server require a hint with min/max queries."""
|
|
||||||
# Changed in SERVER-39567.
|
|
||||||
return self.version.at_least(4, 1, 10)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_bson_size(self):
|
|
||||||
return (self.hello)["maxBsonObjectSize"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_write_batch_size(self):
|
|
||||||
return (self.hello)["maxWriteBatchSize"]
|
|
||||||
|
|
||||||
|
|
||||||
# Reusable client context
|
|
||||||
client_context = ClientContext()
|
|
||||||
|
|
||||||
|
|
||||||
class PyMongoTestCase(unittest.TestCase):
|
|
||||||
def assertEqualCommand(self, expected, actual, msg=None):
|
|
||||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
|
||||||
|
|
||||||
def assertEqualReply(self, expected, actual, msg=None):
|
|
||||||
self.assertEqual(sanitize_reply(expected), sanitize_reply(actual), msg)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fail_point(self, command_args):
|
|
||||||
cmd_on = SON([("configureFailPoint", "failCommand")])
|
|
||||||
cmd_on.update(command_args)
|
|
||||||
client_context.client.admin.command(cmd_on)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
client_context.client.admin.command(
|
|
||||||
"configureFailPoint", cmd_on["configureFailPoint"], mode="off"
|
|
||||||
)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fork(
|
|
||||||
self, target: Callable, timeout: float = 60
|
|
||||||
) -> Generator[multiprocessing.Process, None, None]:
|
|
||||||
"""Helper for tests that use os.fork()
|
|
||||||
|
|
||||||
Use in a with statement:
|
|
||||||
|
|
||||||
with self.fork(target=lambda: print('in child')) as proc:
|
|
||||||
self.assertTrue(proc.pid) # Child process was started
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _print_threads(*args: object) -> None:
|
|
||||||
if _print_threads.called: # type:ignore[attr-defined]
|
|
||||||
return
|
|
||||||
_print_threads.called = True # type:ignore[attr-defined]
|
|
||||||
print_thread_tracebacks()
|
|
||||||
|
|
||||||
_print_threads.called = False # type:ignore[attr-defined]
|
|
||||||
|
|
||||||
def _target() -> None:
|
|
||||||
signal.signal(signal.SIGUSR1, _print_threads)
|
|
||||||
try:
|
|
||||||
target()
|
|
||||||
except Exception as exc:
|
|
||||||
sys.stderr.write(f"Child process failed with: {exc}\n")
|
|
||||||
_print_threads()
|
|
||||||
# Sleep for a while to let the parent attach via GDB.
|
|
||||||
time.sleep(2 * timeout)
|
|
||||||
raise
|
|
||||||
|
|
||||||
ctx = multiprocessing.get_context("fork")
|
|
||||||
proc = ctx.Process(target=_target)
|
|
||||||
proc.start()
|
|
||||||
try:
|
|
||||||
yield proc # type: ignore
|
|
||||||
finally:
|
|
||||||
proc.join(timeout)
|
|
||||||
pid = proc.pid
|
|
||||||
assert pid
|
|
||||||
if proc.exitcode is None:
|
|
||||||
# gdb to get C-level tracebacks
|
|
||||||
print_thread_stacks(pid)
|
|
||||||
# If it failed, SIGUSR1 to get thread tracebacks.
|
|
||||||
os.kill(pid, signal.SIGUSR1)
|
|
||||||
proc.join(5)
|
|
||||||
if proc.exitcode is None:
|
|
||||||
# SIGINT to get main thread traceback in case SIGUSR1 didn't work.
|
|
||||||
os.kill(pid, signal.SIGINT)
|
|
||||||
proc.join(5)
|
|
||||||
if proc.exitcode is None:
|
|
||||||
# SIGKILL in case SIGINT didn't work.
|
|
||||||
proc.kill()
|
|
||||||
proc.join(1)
|
|
||||||
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
|
|
||||||
self.assertEqual(proc.exitcode, 0)
|
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTest(PyMongoTestCase):
|
|
||||||
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
|
||||||
|
|
||||||
client: MongoClient[dict]
|
|
||||||
db: Database
|
|
||||||
credentials: Dict[str, str]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_connection
|
|
||||||
def _setup_class(cls):
|
|
||||||
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
|
||||||
raise SkipTest("this test does not support load balancers")
|
|
||||||
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
|
||||||
raise SkipTest("this test does not support serverless")
|
|
||||||
cls.client = client_context.client
|
|
||||||
cls.db = cls.client.pymongo_test
|
|
||||||
if client_context.auth_enabled:
|
|
||||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
|
||||||
else:
|
|
||||||
cls.credentials = {}
|
|
||||||
|
|
||||||
def cleanup_colls(self, *collections):
|
|
||||||
"""Cleanup collections faster than drop_collection."""
|
|
||||||
for c in collections:
|
|
||||||
c = self.client[c.database.name][c.name]
|
|
||||||
c.delete_many({})
|
|
||||||
c.drop_indexes()
|
|
||||||
|
|
||||||
def patch_system_certs(self, ca_certs):
|
|
||||||
patcher = SystemCertsPatcher(ca_certs)
|
|
||||||
self.addCleanup(patcher.disable)
|
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
|
||||||
client_context.init()
|
|
||||||
warnings.resetwarnings()
|
|
||||||
warnings.simplefilter("always")
|
|
||||||
global_knobs.enable()
|
|
||||||
|
|
||||||
|
|
||||||
def teardown():
|
|
||||||
global_knobs.disable()
|
|
||||||
garbage = []
|
|
||||||
for g in gc.garbage:
|
|
||||||
garbage.append(f"GARBAGE: {g!r}")
|
|
||||||
garbage.append(f" gc.get_referents: {gc.get_referents(g)!r}")
|
|
||||||
garbage.append(f" gc.get_referrers: {gc.get_referrers(g)!r}")
|
|
||||||
if garbage:
|
|
||||||
raise AssertionError("\n".join(garbage))
|
|
||||||
c = client_context.client
|
|
||||||
if c:
|
|
||||||
if not client_context.is_data_lake:
|
|
||||||
c.drop_database("pymongo-pooling-tests")
|
|
||||||
c.drop_database("pymongo_test")
|
|
||||||
c.drop_database("pymongo_test1")
|
|
||||||
c.drop_database("pymongo_test2")
|
|
||||||
c.drop_database("pymongo_test_mike")
|
|
||||||
c.drop_database("pymongo_test_bernie")
|
|
||||||
c.close()
|
|
||||||
|
|
||||||
print_running_clients()
|
|
||||||
|
|
||||||
|
|
||||||
def test_cases(suite):
|
|
||||||
"""Iterator over all TestCases within a TestSuite."""
|
|
||||||
for suite_or_case in suite._tests:
|
|
||||||
if isinstance(suite_or_case, unittest.TestCase):
|
|
||||||
# unittest.TestCase
|
|
||||||
yield suite_or_case
|
|
||||||
else:
|
|
||||||
# unittest.TestSuite
|
|
||||||
yield from test_cases(suite_or_case)
|
|
||||||
@ -1,14 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from test import setup, teardown
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
_IS_SYNC = True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def test_setup_and_teardown():
|
|
||||||
setup()
|
|
||||||
yield
|
|
||||||
teardown()
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -36,6 +36,12 @@ from pymongo.synchronous.collection import Collection
|
|||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import (
|
from test import (
|
||||||
|
IntegrationTest,
|
||||||
|
PyMongoTestCase,
|
||||||
|
client_context,
|
||||||
|
unittest,
|
||||||
|
)
|
||||||
|
from test.helpers import (
|
||||||
AWS_CREDS,
|
AWS_CREDS,
|
||||||
AZURE_CREDS,
|
AZURE_CREDS,
|
||||||
CA_PEM,
|
CA_PEM,
|
||||||
@ -43,10 +49,6 @@ from test import (
|
|||||||
GCP_CREDS,
|
GCP_CREDS,
|
||||||
KMIP_CREDS,
|
KMIP_CREDS,
|
||||||
LOCAL_MASTER_KEY,
|
LOCAL_MASTER_KEY,
|
||||||
IntegrationTest,
|
|
||||||
PyMongoTestCase,
|
|
||||||
client_context,
|
|
||||||
unittest,
|
|
||||||
)
|
)
|
||||||
from test.test_bulk import BulkTestBase
|
from test.test_bulk import BulkTestBase
|
||||||
from test.unified_format import generate_test_classes
|
from test.unified_format import generate_test_classes
|
||||||
|
|||||||
@ -24,7 +24,8 @@ import warnings
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import clear_warning_registry, unittest
|
from test import unittest
|
||||||
|
from test.helpers import clear_warning_registry
|
||||||
|
|
||||||
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
|
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
|
||||||
from pymongo.compression_support import _have_snappy
|
from pymongo.compression_support import _have_snappy
|
||||||
|
|||||||
@ -31,6 +31,11 @@ import traceback
|
|||||||
import types
|
import types
|
||||||
from collections import abc, defaultdict
|
from collections import abc, defaultdict
|
||||||
from test import (
|
from test import (
|
||||||
|
IntegrationTest,
|
||||||
|
client_context,
|
||||||
|
unittest,
|
||||||
|
)
|
||||||
|
from test.helpers import (
|
||||||
AWS_CREDS,
|
AWS_CREDS,
|
||||||
AWS_CREDS_2,
|
AWS_CREDS_2,
|
||||||
AZURE_CREDS,
|
AZURE_CREDS,
|
||||||
@ -39,9 +44,6 @@ from test import (
|
|||||||
GCP_CREDS,
|
GCP_CREDS,
|
||||||
KMIP_CREDS,
|
KMIP_CREDS,
|
||||||
LOCAL_MASTER_KEY,
|
LOCAL_MASTER_KEY,
|
||||||
IntegrationTest,
|
|
||||||
client_context,
|
|
||||||
unittest,
|
|
||||||
)
|
)
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
CMAPListener,
|
CMAPListener,
|
||||||
|
|||||||
@ -40,6 +40,7 @@ replacements = {
|
|||||||
"async_receive_message": "receive_message",
|
"async_receive_message": "receive_message",
|
||||||
"async_sendall": "sendall",
|
"async_sendall": "sendall",
|
||||||
"asynchronous": "synchronous",
|
"asynchronous": "synchronous",
|
||||||
|
"Asynchronous": "Synchronous",
|
||||||
"anext": "next",
|
"anext": "next",
|
||||||
"_ALock": "_Lock",
|
"_ALock": "_Lock",
|
||||||
"_ACondition": "_Condition",
|
"_ACondition": "_Condition",
|
||||||
@ -60,6 +61,7 @@ replacements = {
|
|||||||
"AsyncTestCollection": "TestCollection",
|
"AsyncTestCollection": "TestCollection",
|
||||||
"AsyncIntegrationTest": "IntegrationTest",
|
"AsyncIntegrationTest": "IntegrationTest",
|
||||||
"AsyncPyMongoTestCase": "PyMongoTestCase",
|
"AsyncPyMongoTestCase": "PyMongoTestCase",
|
||||||
|
"AsyncMockClientTest": "MockClientTest",
|
||||||
"async_client_context": "client_context",
|
"async_client_context": "client_context",
|
||||||
"async_setup": "setup",
|
"async_setup": "setup",
|
||||||
"asyncSetUp": "setUp",
|
"asyncSetUp": "setUp",
|
||||||
@ -100,7 +102,7 @@ _test_base = "./test/asynchronous/"
|
|||||||
|
|
||||||
_pymongo_dest_base = "./pymongo/synchronous/"
|
_pymongo_dest_base = "./pymongo/synchronous/"
|
||||||
_gridfs_dest_base = "./gridfs/synchronous/"
|
_gridfs_dest_base = "./gridfs/synchronous/"
|
||||||
_test_dest_base = "./test/synchronous/"
|
_test_dest_base = "./test/"
|
||||||
|
|
||||||
|
|
||||||
async_files = [
|
async_files = [
|
||||||
@ -125,8 +127,15 @@ sync_gridfs_files = [
|
|||||||
if (Path(_gridfs_dest_base) / f).is_file()
|
if (Path(_gridfs_dest_base) / f).is_file()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Add each asynchronized test here as part of the converting PR
|
||||||
|
converted_tests = [
|
||||||
|
"__init__.py",
|
||||||
|
"conftest.py",
|
||||||
|
"test_collection.py",
|
||||||
|
]
|
||||||
|
|
||||||
sync_test_files = [
|
sync_test_files = [
|
||||||
_test_dest_base + f for f in listdir(_test_dest_base) if (Path(_test_dest_base) / f).is_file()
|
_test_dest_base + f for f in converted_tests if (Path(_test_dest_base) / f).is_file()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -245,7 +254,7 @@ def translate_docstrings(lines: list[str]) -> list[str]:
|
|||||||
if "An asynchronous" in lines[i]:
|
if "An asynchronous" in lines[i]:
|
||||||
lines[i] = lines[i].replace("An asynchronous", "A")
|
lines[i] = lines[i].replace("An asynchronous", "A")
|
||||||
lines[i] = lines[i].replace(k, replacements[k])
|
lines[i] = lines[i].replace(k, replacements[k])
|
||||||
if "Sync" in lines[i] and replacements[k] in lines[i]:
|
if "Sync" in lines[i] and "Synchronous" not in lines[i] and replacements[k] in lines[i]:
|
||||||
lines[i] = lines[i].replace("Sync", "")
|
lines[i] = lines[i].replace("Sync", "")
|
||||||
for i in range(len(lines)):
|
for i in range(len(lines)):
|
||||||
for k in docstring_replacements: # type: ignore[assignment]
|
for k in docstring_replacements: # type: ignore[assignment]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash -eu
|
#!/bin/bash -eu
|
||||||
|
|
||||||
python ./tools/synchro.py
|
python ./tools/synchro.py
|
||||||
python -m ruff check pymongo/synchronous/ gridfs/synchronous/ test/synchronous --fix --silent
|
python -m ruff check pymongo/synchronous/ gridfs/synchronous/ test/ --fix --silent
|
||||||
python -m ruff format pymongo/synchronous/ gridfs/synchronous/ test/synchronous --silent
|
python -m ruff format pymongo/synchronous/ gridfs/synchronous/ test/ --silent
|
||||||
|
|||||||
1
tox.ini
1
tox.ini
@ -66,7 +66,6 @@ extras =
|
|||||||
test
|
test
|
||||||
commands =
|
commands =
|
||||||
pytest -v --durations=5 --maxfail=10 {posargs}
|
pytest -v --durations=5 --maxfail=10 {posargs}
|
||||||
pytest -v --durations=5 --maxfail=10 test/synchronous/ {posargs}
|
|
||||||
|
|
||||||
[testenv:test-async]
|
[testenv:test-async]
|
||||||
description = run base set of async unit tests with no extra functionality
|
description = run base set of async unit tests with no extra functionality
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user