mongo-python-driver/test/helpers.py

358 lines
12 KiB
Python

# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import base64
import gc
import multiprocessing
import os
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from functools import wraps
from typing import Any, Callable, Dict, Generator, no_type_check
from unittest import SkipTest
from bson.son import SON
from pymongo import common, message
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.uri_parser import parse_uri
if HAVE_SSL:
import ssl
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
if hasattr(gc, "set_debug"):
gc.set_debug(
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
)
# The host and port of a single mongod or mongos, or the seed host
# for a replica set.
host = os.environ.get("DB_IP", "localhost")
port = int(os.environ.get("DB_PORT", 27017))
IS_SRV = "mongodb+srv" in host
db_user = os.environ.get("DB_USER", "user")
db_pwd = os.environ.get("DB_PASSWORD", "password")
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
TLS_OPTIONS: Dict = {"tls": True}
if CLIENT_PEM:
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
if CA_PEM:
TLS_OPTIONS["tlsCAFile"] = CA_PEM
COMPRESSORS = os.environ.get("COMPRESSORS")
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS"))
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
if TEST_LOADBALANCER:
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res["nodelist"][0]
db_user = res["username"] or db_user
db_pwd = res["password"] or db_pwd
elif TEST_SERVERLESS:
TEST_LOADBALANCER = True
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res["nodelist"][0]
db_user = res["username"] or db_user
db_pwd = res["password"] or db_pwd
TLS_OPTIONS = {"tls": True}
# Spec says serverless tests must be run with compression.
COMPRESSORS = COMPRESSORS or "zlib"
# Shared KMS data.
LOCAL_MASTER_KEY = base64.b64decode(
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
)
AWS_CREDS = {
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
}
AWS_CREDS_2 = {
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
}
AZURE_CREDS = {
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
}
GCP_CREDS = {
"email": os.environ.get("FLE_GCP_EMAIL", ""),
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
}
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
# Ensure Evergreen metadata doesn't result in truncation
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
def is_server_resolvable():
"""Returns True if 'server' is resolvable."""
socket_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(1)
try:
try:
socket.gethostbyname("server")
return True
except OSError:
return False
finally:
socket.setdefaulttimeout(socket_timeout)
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
cmd = SON([("createUser", user)])
# X509 doesn't use a password
if pwd:
cmd["pwd"] = pwd
cmd["roles"] = roles or ["root"]
cmd.update(**kwargs)
return authdb.command(cmd)
class client_knobs:
def __init__(
self,
heartbeat_frequency=None,
min_heartbeat_interval=None,
kill_cursor_frequency=None,
events_queue_frequency=None,
):
self.heartbeat_frequency = heartbeat_frequency
self.min_heartbeat_interval = min_heartbeat_interval
self.kill_cursor_frequency = kill_cursor_frequency
self.events_queue_frequency = events_queue_frequency
self.old_heartbeat_frequency = None
self.old_min_heartbeat_interval = None
self.old_kill_cursor_frequency = None
self.old_events_queue_frequency = None
self._enabled = False
self._stack = None
def enable(self):
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
if self.heartbeat_frequency is not None:
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
if self.min_heartbeat_interval is not None:
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
if self.kill_cursor_frequency is not None:
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
if self.events_queue_frequency is not None:
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
self._enabled = True
# Store the allocation traceback to catch non-disabled client_knobs.
self._stack = "".join(traceback.format_stack())
def __enter__(self):
self.enable()
@no_type_check
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
self._enabled = False
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
def __call__(self, func):
def make_wrapper(f):
@wraps(f)
def wrap(*args, **kwargs):
with self:
return f(*args, **kwargs)
return wrap
return make_wrapper(func)
def __del__(self):
if self._enabled:
msg = (
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
common.HEARTBEAT_FREQUENCY,
common.MIN_HEARTBEAT_INTERVAL,
common.KILL_CURSOR_FREQUENCY,
common.EVENTS_QUEUE_FREQUENCY,
self._stack,
)
)
self.disable()
raise Exception(msg)
def _all_users(db):
return {u["user"] for u in db.command("usersInfo").get("users", [])}
def sanitize_cmd(cmd):
cp = cmd.copy()
cp.pop("$clusterTime", None)
cp.pop("$db", None)
cp.pop("$readPreference", None)
cp.pop("lsid", None)
if MONGODB_API_VERSION:
# Stable API parameters
cp.pop("apiVersion", None)
# OP_MSG encoding may move the payload type one field to the
# end of the command. Do the same here.
name = next(iter(cp))
try:
identifier = message._FIELD_MAP[name]
docs = cp.pop(identifier)
cp[identifier] = docs
except KeyError:
pass
return cp
def sanitize_reply(reply):
cp = reply.copy()
cp.pop("$clusterTime", None)
cp.pop("operationTime", None)
return cp
def print_thread_tracebacks() -> None:
"""Print all Python thread tracebacks."""
for thread_id, frame in sys._current_frames().items():
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
traceback.print_stack(frame, file=sys.stderr)
def print_thread_stacks(pid: int) -> None:
"""Print all C-level thread stacks for a given process id."""
if sys.platform == "darwin":
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
else:
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
try:
res = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
)
except Exception as exc:
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
else:
sys.stderr.write(res.stdout)
# Global knobs to speed up the test suite.
global_knobs = client_knobs(events_queue_frequency=0.05)
def _get_executors(topology):
executors = []
for server in topology._servers.values():
# Some MockMonitor do not have an _executor.
if hasattr(server._monitor, "_executor"):
executors.append(server._monitor._executor)
if hasattr(server._monitor, "_rtt_monitor"):
executors.append(server._monitor._rtt_monitor._executor)
executors.append(topology._Topology__events_executor)
if topology._srv_monitor:
executors.append(topology._srv_monitor._executor)
return [e for e in executors if e is not None]
def print_running_topology(topology):
running = [e for e in _get_executors(topology) if not e._stopped]
if running:
print(
"WARNING: found Topology with running threads:\n"
f" Threads: {running}\n"
f" Topology: {topology}\n"
f" Creation traceback:\n{topology._settings._stack}"
)
def test_cases(suite):
"""Iterator over all TestCases within a TestSuite."""
for suite_or_case in suite._tests:
if isinstance(suite_or_case, unittest.TestCase):
# unittest.TestCase
yield suite_or_case
else:
# unittest.TestSuite
yield from test_cases(suite_or_case)
# Helper method to workaround https://bugs.python.org/issue21724
def clear_warning_registry():
"""Clear the __warningregistry__ for all modules."""
for _, module in list(sys.modules.items()):
if hasattr(module, "__warningregistry__"):
module.__warningregistry__ = {} # type:ignore[attr-defined]
class SystemCertsPatcher:
def __init__(self, ca_certs):
if (
ssl.OPENSSL_VERSION.lower().startswith("libressl")
and sys.platform == "darwin"
and not _ssl.IS_PYOPENSSL
):
raise SkipTest(
"LibreSSL on OSX doesn't support setting CA certificates "
"using SSL_CERT_FILE environment variable."
)
self.original_certs = os.environ.get("SSL_CERT_FILE")
# Tell OpenSSL where CA certificates live.
os.environ["SSL_CERT_FILE"] = ca_certs
def disable(self):
if self.original_certs is None:
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs