Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-09-17 18:41:56 -05:00
commit 54fd7b6104
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
86 changed files with 1736 additions and 1566 deletions

View File

@ -2081,10 +2081,10 @@ axes:
batchtime: 10080 # 7 days
- id: rhel8
display_name: "RHEL 8.x"
run_on: rhel87-small
run_on: rhel8.8-small
batchtime: 10080 # 7 days
- id: rhel92-fips
display_name: "RHEL 9.2 FIPS"
- id: rhel9-fips
display_name: "RHEL 9 FIPS"
run_on: rhel92-fips
batchtime: 10080 # 7 days
- id: ubuntu-22.04
@ -2095,24 +2095,24 @@ axes:
display_name: "Ubuntu 20.04"
run_on: ubuntu2004-small
batchtime: 10080 # 7 days
- id: rhel83-zseries
display_name: "RHEL 8.3 (zSeries)"
run_on: rhel83-zseries-small
- id: rhel8-zseries
display_name: "RHEL 8 (zSeries)"
run_on: rhel8-zseries-small
batchtime: 10080 # 7 days
variables:
SKIP_HATCH: true
- id: rhel81-power8
display_name: "RHEL 8.1 (POWER8)"
run_on: rhel81-power8-small
- id: rhel8-power8
display_name: "RHEL 8 (POWER8)"
run_on: rhel8-power-small
batchtime: 10080 # 7 days
variables:
SKIP_HATCH: true
- id: rhel82-arm64
display_name: "RHEL 8.2 (ARM64)"
- id: rhel8-arm64
display_name: "RHEL 8 (ARM64)"
run_on: rhel82-arm64-small
batchtime: 10080 # 7 days
variables:
- id: windows-64-vsMulti-small
- id: windows
display_name: "Windows 64"
run_on: windows-64-vsMulti-small
batchtime: 10080 # 7 days
@ -2470,7 +2470,7 @@ buildvariants:
- matrix_name: "tests-fips"
matrix_spec:
platform:
- rhel92-fips
- rhel9-fips
auth: "auth"
ssl: "ssl"
display_name: "${platform} ${auth} ${ssl}"
@ -2547,9 +2547,9 @@ buildvariants:
- matrix_name: "test-different-cpu-architectures"
matrix_spec:
platform:
- rhel83-zseries # Added in 5.0.8 (SERVER-44074)
- rhel81-power8 # Added in 4.2.7 (SERVER-44072)
- rhel82-arm64 # Added in 4.4.2 (SERVER-48282)
- rhel8-zseries # Added in 5.0.8 (SERVER-44074)
- rhel8-power8 # Added in 4.2.7 (SERVER-44072)
- rhel8-arm64 # Added in 4.4.2 (SERVER-48282)
auth-ssl: "*"
display_name: "${platform} ${auth-ssl}"
tasks:
@ -2606,7 +2606,7 @@ buildvariants:
- matrix_name: "tests-pyopenssl-windows"
matrix_spec:
platform: windows-64-vsMulti-small
platform: windows
python-version-windows: "*"
auth: "auth"
ssl: "ssl"
@ -2698,7 +2698,7 @@ buildvariants:
- matrix_name: "tests-windows-python-version"
matrix_spec:
platform: windows-64-vsMulti-small
platform: windows
python-version-windows: "*"
auth-ssl: "*"
display_name: "${platform} ${python-version-windows} ${auth-ssl}"
@ -2706,7 +2706,7 @@ buildvariants:
- matrix_name: "tests-windows-python-version-32-bit"
matrix_spec:
platform: windows-64-vsMulti-small
platform: windows
python-version-windows-32: "*"
auth-ssl: "*"
display_name: "${platform} ${python-version-windows-32} ${auth-ssl}"
@ -2724,7 +2724,7 @@ buildvariants:
- matrix_name: "tests-windows-encryption"
matrix_spec:
platform: windows-64-vsMulti-small
platform: windows
python-version-windows: "*"
auth-ssl: "*"
encryption: "*"
@ -2733,7 +2733,7 @@ buildvariants:
rules:
- if:
encryption: ["encryption", "encryption_crypt_shared"]
platform: windows-64-vsMulti-small
platform: windows
python-version-windows: "*"
auth-ssl: "*"
then:
@ -2795,7 +2795,7 @@ buildvariants:
- matrix_name: "tests-windows-enterprise-auth"
matrix_spec:
platform: windows-64-vsMulti-small
platform: windows
python-version-windows: "*"
auth: "auth"
display_name: "Enterprise ${auth} ${platform} ${python-version-windows}"
@ -2907,7 +2907,7 @@ buildvariants:
- matrix_name: "ocsp-test-windows"
matrix_spec:
platform: windows-64-vsMulti-small
platform: windows
python-version-windows: ["3.8", "3.10"]
mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"]
auth: "noauth"
@ -2932,7 +2932,7 @@ buildvariants:
- matrix_name: "oidc-auth-test"
matrix_spec:
platform: [ rhel8, macos, windows-64-vsMulti-small ]
platform: [ rhel8, macos, windows ]
display_name: "OIDC Auth ${platform}"
tasks:
- name: testoidc_task_group
@ -2981,7 +2981,7 @@ buildvariants:
- matrix_name: "aws-auth-test-windows"
matrix_spec:
platform: [windows-64-vsMulti-small]
platform: [windows]
python-version-windows: "*"
display_name: "MONGODB-AWS Auth ${platform} ${python-version-windows}"
tasks:

View File

@ -163,13 +163,10 @@ hatch run lint:build-manual
## Documentation
To contribute to the [API
documentation](https://pymongo.readthedocs.io/en/stable/) just make your
changes to the inline documentation of the appropriate [source
code](https://github.com/mongodb/mongo-python-driver) or [rst
file](https://github.com/mongodb/mongo-python-driver/tree/master/doc) in
a branch and submit a [pull
request](https://help.github.com/articles/using-pull-requests). You
To contribute to the [API documentation](https://pymongo.readthedocs.io/en/stable/) just make your
changes to the inline documentation of the appropriate [source code](https://github.com/mongodb/mongo-python-driver) or
[rst file](https://github.com/mongodb/mongo-python-driver/tree/master/doc) in
a branch and submit a [pull request](https://help.github.com/articles/using-pull-requests). You
might also use the GitHub
[Edit](https://github.com/blog/844-forking-with-the-edit-button) button.

View File

@ -3,7 +3,7 @@
[![PyPI Version](https://img.shields.io/pypi/v/pymongo)](https://pypi.org/project/pymongo)
[![Python Versions](https://img.shields.io/pypi/pyversions/pymongo)](https://pypi.org/project/pymongo)
[![Monthly Downloads](https://static.pepy.tech/badge/pymongo/month)](https://pepy.tech/project/pymongo)
[![Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/?badge=stable)
[![API Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/api?badge=stable)
## About

View File

@ -306,7 +306,7 @@ static PyObject* datetime_from_millis(long long millis) {
if (evalue) {
PyObject* err_msg = PyObject_Str(evalue);
if (err_msg) {
PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes");
PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes");
if (appendage) {
PyObject* msg = PyUnicode_Concat(err_msg, appendage);
if (msg) {

View File

@ -31,7 +31,7 @@ EPOCH_NAIVE = EPOCH_AWARE.replace(tzinfo=None)
_DATETIME_ERROR_SUGGESTION = (
"(Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO)"
" or MongoClient(datetime_conversion='DATETIME_AUTO'))."
" See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes"
" See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes"
)

View File

@ -93,6 +93,10 @@ Unavoidable breaking changes
- Since we are now using ``hatch`` as our build backend, we no longer have a usable ``setup.py`` file
and require installation using ``pip``. Attempts to invoke the ``setup.py`` file will raise an exception.
Additionally, ``pip`` >= 21.3 is now required for editable installs.
- We no longer support the ``srv`` extra, since ``dnspython`` is included as a dependency in PyMongo 4.7+.
Instead of ``pip install pymongo[srv]``, use ``pip install pymongo``.
- We no longer support the ``tls`` extra, which was only valid for Python 2.
Instead of ``pip install pymongo[tls]``, use ``pip install pymongo``.
Issues Resolved
...............

View File

@ -1,6 +1,11 @@
PyMongo |release| Documentation
===============================
.. note:: The PyMongo documentation has been migrated to the
`MongoDB Documentation site <https://www.mongodb.com/docs/languages/python/pymongo-driver/current>`_.
As of PyMongo 4.10, the ReadTheDocs site will contain the detailed changelog and API docs, while the
rest of the documentation will only appear on the MongoDB Documentation site.
Overview
--------
**PyMongo** is a Python distribution containing tools for working with
@ -95,8 +100,6 @@ pull request.
Changes
-------
See the :doc:`changelog` for a full list of changes to PyMongo.
For older versions of the documentation please see the
`archive list <http://api.mongodb.org/python/>`_.
About This Documentation
------------------------

View File

@ -29,6 +29,7 @@ from typing import (
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo import _csot
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
@ -77,6 +78,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
self._address = address
self._batch_size = batch_size
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
self._killed = self._id == 0
@ -385,6 +387,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
@_csot.apply
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.

View File

@ -36,7 +36,7 @@ from typing import (
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
from bson.code import Code
from bson.son import SON
from pymongo import helpers_shared
from pymongo import _csot, helpers_shared
from pymongo.asynchronous.helpers import anext
from pymongo.collation import validate_collation_or_none
from pymongo.common import (
@ -196,6 +196,7 @@ class AsyncCursor(Generic[_DocumentType]):
self._explain = False
self._comment = comment
self._max_time_ms = max_time_ms
self._timeout = self._collection.database.client.options.timeout
self._max_await_time_ms: Optional[int] = None
self._max: Optional[Union[dict[Any, Any], _Sort]] = max
self._min: Optional[Union[dict[Any, Any], _Sort]] = min
@ -1290,6 +1291,7 @@ class AsyncCursor(Generic[_DocumentType]):
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
@_csot.apply
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.

View File

@ -1193,7 +1193,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
),
ResourceWarning,
stacklevel=2,
source=self,
)
except AttributeError:
pass

View File

@ -227,8 +227,9 @@ class Topology:
warnings.warn( # type: ignore[call-overload] # noqa: B028
"AsyncMongoClient opened before fork. May not be entirely fork-safe, "
"proceed with caution. See PyMongo's documentation for details: "
"https://pymongo.readthedocs.io/en/stable/faq.html#"
"is-pymongo-fork-safe",
"https://www.mongodb.com/docs/languages/"
"python/pymongo-driver/current/faq/"
"#is-pymongo-fork-safe-",
**kwargs,
)
async with self._lock:

View File

@ -29,6 +29,7 @@ from typing import (
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo import _csot
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import (
@ -77,6 +78,7 @@ class CommandCursor(Generic[_DocumentType]):
self._address = address
self._batch_size = batch_size
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
self._killed = self._id == 0
@ -385,6 +387,7 @@ class CommandCursor(Generic[_DocumentType]):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
@_csot.apply
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.

View File

@ -36,7 +36,7 @@ from typing import (
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
from bson.code import Code
from bson.son import SON
from pymongo import helpers_shared
from pymongo import _csot, helpers_shared
from pymongo.collation import validate_collation_or_none
from pymongo.common import (
validate_is_document_type,
@ -196,6 +196,7 @@ class Cursor(Generic[_DocumentType]):
self._explain = False
self._comment = comment
self._max_time_ms = max_time_ms
self._timeout = self._collection.database.client.options.timeout
self._max_await_time_ms: Optional[int] = None
self._max: Optional[Union[dict[Any, Any], _Sort]] = max
self._min: Optional[Union[dict[Any, Any], _Sort]] = min
@ -1288,6 +1289,7 @@ class Cursor(Generic[_DocumentType]):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
@_csot.apply
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.

View File

@ -1193,7 +1193,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
),
ResourceWarning,
stacklevel=2,
source=self,
)
except AttributeError:
pass

View File

@ -227,8 +227,9 @@ class Topology:
warnings.warn( # type: ignore[call-overload] # noqa: B028
"MongoClient opened before fork. May not be entirely fork-safe, "
"proceed with caution. See PyMongo's documentation for details: "
"https://pymongo.readthedocs.io/en/stable/faq.html#"
"is-pymongo-fork-safe",
"https://www.mongodb.com/docs/languages/"
"python/pymongo-driver/current/faq/"
"#is-pymongo-fork-safe-",
**kwargs,
)
with self._lock:

View File

@ -42,7 +42,7 @@ classifiers = [
[project.urls]
Homepage = "https://www.mongodb.org"
Documentation = "https://pymongo.readthedocs.io"
Documentation = "https://www.mongodb.com/docs/languages/python/pymongo-driver/current/"
Source = "https://github.com/mongodb/mongo-python-driver"
Tracker = "https://jira.mongodb.org/projects/PYTHON/issues"
@ -96,9 +96,6 @@ filterwarnings = [
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
# https://github.com/dateutil/dateutil/issues/1314
"module:datetime.datetime.utc:DeprecationWarning:dateutil",
# TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731
"ignore:Unclosed AsyncMongoClient*",
"ignore:Unclosed MongoClient*",
]
markers = [
"auth_aws: tests that rely on pymongo-auth-aws",

View File

@ -16,8 +16,6 @@
from __future__ import annotations
import asyncio
import base64
import contextlib
import gc
import multiprocessing
import os
@ -27,7 +25,6 @@ import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
@ -54,6 +51,8 @@ from test.helpers import (
sanitize_reply,
)
from pymongo.uri_parser import parse_uri
try:
import ipaddress
@ -80,6 +79,12 @@ from pymongo.synchronous.mongo_client import MongoClient
_IS_SYNC = True
def _connection_string(h):
if h.startswith(("mongodb://", "mongodb+srv://")):
return h
return f"mongodb://{h!s}"
class ClientContext:
client: MongoClient
@ -230,6 +235,9 @@ class ClientContext:
if not self._check_user_provided():
_create_user(self.client.admin, db_user, db_pwd)
if self.client:
self.client.close()
self.client = self._connect(
host,
port,
@ -256,6 +264,8 @@ class ClientContext:
if "setName" in hello:
self.replica_set_name = str(hello["setName"])
self.is_rs = True
if self.client:
self.client.close()
if self.auth_enabled:
# It doesn't matter which member we use as the seed here.
self.client = pymongo.MongoClient(
@ -318,6 +328,7 @@ class ClientContext:
hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD)
if hello.get("msg") == "isdbgrid":
self.mongoses.append(next_address)
mongos_client.close()
def init(self):
with self.conn_lock:
@ -537,12 +548,6 @@ class ClientContext:
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
)
def require_no_fips(self, func):
"""Run a test only if the host does not have FIPS enabled."""
return self._require(
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
)
def require_no_auth(self, func):
"""Run a test only if the server is running without auth enabled."""
return self._require(
@ -930,6 +935,172 @@ class PyMongoTestCase(unittest.TestCase):
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
self.assertEqual(proc.exitcode, 0)
@classmethod
def _unmanaged_async_mongo_client(
cls, host, port, authenticate=True, directConnection=None, **kwargs
):
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options: dict = client_context.default_client_options.copy()
if client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)
uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = MongoClient(uri, port, **client_options)
if client._options.connect:
client._connect()
return client
def _async_mongo_client(self, host, port, authenticate=True, directConnection=None, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options: dict = client_context.default_client_options.copy()
if client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)
uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = MongoClient(uri, port, **client_options)
if client._options.connect:
client._connect()
self.addCleanup(client.close)
return client
@classmethod
def unmanaged_single_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(
h, p, authenticate=False, directConnection=True, **kwargs
)
@classmethod
def unmanaged_single_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)
@classmethod
def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return cls._unmanaged_async_mongo_client(h, p, **kwargs)
@classmethod
def unmanaged_rs_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
@classmethod
def unmanaged_rs_or_single_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
@classmethod
def unmanaged_rs_or_single_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, **kwargs)
def single_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return self._async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
def single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Make a direct connection, and authenticate if necessary."""
return self._async_mongo_client(h, p, directConnection=True, **kwargs)
def rs_client_noauth(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set. Don't authenticate."""
return self._async_mongo_client(h, p, authenticate=False, **kwargs)
def rs_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return self._async_mongo_client(h, p, **kwargs)
def rs_or_single_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
"""
return self._async_mongo_client(h, p, authenticate=False, **kwargs)
def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.
"""
return self._async_mongo_client(h, p, **kwargs)
def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient:
if not h and not p:
client = MongoClient(**kwargs)
else:
client = MongoClient(h, p, **kwargs)
self.addCleanup(client.close)
return client
@classmethod
def unmanaged_simple_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient:
if not h and not p:
client = MongoClient(**kwargs)
else:
client = MongoClient(h, p, **kwargs)
return client
def disable_replication(self, client):
"""Disable replication on all secondaries."""
for h, p in client.secondaries:
secondary = self.single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
def enable_replication(self, client):
"""Enable replication on all secondaries."""
for h, p in client.secondaries:
secondary = self.single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""

View File

@ -16,8 +16,6 @@
from __future__ import annotations
import asyncio
import base64
import contextlib
import gc
import multiprocessing
import os
@ -27,7 +25,6 @@ import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
@ -54,6 +51,8 @@ from test.helpers import (
sanitize_reply,
)
from pymongo.uri_parser import parse_uri
try:
import ipaddress
@ -80,6 +79,12 @@ from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
_IS_SYNC = False
def _connection_string(h):
if h.startswith(("mongodb://", "mongodb+srv://")):
return h
return f"mongodb://{h!s}"
class AsyncClientContext:
client: AsyncMongoClient
@ -230,6 +235,9 @@ class AsyncClientContext:
if not await self._check_user_provided():
await _create_user(self.client.admin, db_user, db_pwd)
if self.client:
await self.client.close()
self.client = await self._connect(
host,
port,
@ -256,6 +264,8 @@ class AsyncClientContext:
if "setName" in hello:
self.replica_set_name = str(hello["setName"])
self.is_rs = True
if self.client:
await self.client.close()
if self.auth_enabled:
# It doesn't matter which member we use as the seed here.
self.client = pymongo.AsyncMongoClient(
@ -320,6 +330,7 @@ class AsyncClientContext:
hello = await mongos_client.admin.command(HelloCompat.LEGACY_CMD)
if hello.get("msg") == "isdbgrid":
self.mongoses.append(next_address)
await mongos_client.close()
async def init(self):
with self.conn_lock:
@ -539,12 +550,6 @@ class AsyncClientContext:
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
)
def require_no_fips(self, func):
"""Run a test only if the host does not have FIPS enabled."""
return self._require(
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
)
def require_no_auth(self, func):
"""Run a test only if the server is running without auth enabled."""
return self._require(
@ -932,6 +937,188 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
self.assertEqual(proc.exitcode, 0)
@classmethod
async def _unmanaged_async_mongo_client(
cls, host, port, authenticate=True, directConnection=None, **kwargs
):
"""Create a new client over SSL/TLS if necessary."""
host = host or await async_client_context.host
port = port or await async_client_context.port
client_options: dict = async_client_context.default_client_options.copy()
if async_client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = async_client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)
uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = AsyncMongoClient(uri, port, **client_options)
if client._options.connect:
await client.aconnect()
return client
async def _async_mongo_client(
self, host, port, authenticate=True, directConnection=None, **kwargs
):
"""Create a new client over SSL/TLS if necessary."""
host = host or await async_client_context.host
port = port or await async_client_context.port
client_options: dict = async_client_context.default_client_options.copy()
if async_client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = async_client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)
uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = AsyncMongoClient(uri, port, **client_options)
if client._options.connect:
await client.aconnect()
self.addAsyncCleanup(client.close)
return client
@classmethod
async def unmanaged_async_single_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await cls._unmanaged_async_mongo_client(
h, p, authenticate=False, directConnection=True, **kwargs
)
@classmethod
async def unmanaged_async_single_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)
@classmethod
async def unmanaged_async_rs_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return await cls._unmanaged_async_mongo_client(h, p, **kwargs)
@classmethod
async def unmanaged_async_rs_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
@classmethod
async def unmanaged_async_rs_or_single_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
@classmethod
async def unmanaged_async_rs_or_single_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await cls._unmanaged_async_mongo_client(h, p, **kwargs)
async def async_single_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await self._async_mongo_client(
h, p, authenticate=False, directConnection=True, **kwargs
)
async def async_single_client(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection, and authenticate if necessary."""
return await self._async_mongo_client(h, p, directConnection=True, **kwargs)
async def async_rs_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Connect to the replica set. Don't authenticate."""
return await self._async_mongo_client(h, p, authenticate=False, **kwargs)
async def async_rs_client(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return await self._async_mongo_client(h, p, **kwargs)
async def async_rs_or_single_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
"""
return await self._async_mongo_client(h, p, authenticate=False, **kwargs)
async def async_rs_or_single_client(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[Any]:
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.
"""
return await self._async_mongo_client(h, p, **kwargs)
def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient:
if not h and not p:
client = AsyncMongoClient(**kwargs)
else:
client = AsyncMongoClient(h, p, **kwargs)
self.addAsyncCleanup(client.close)
return client
@classmethod
def unmanaged_simple_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient:
if not h and not p:
client = AsyncMongoClient(**kwargs)
else:
client = AsyncMongoClient(h, p, **kwargs)
return client
async def disable_replication(self, client):
"""Disable replication on all secondaries."""
for h, p in client.secondaries:
secondary = await self.async_single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
async def enable_replication(self, client):
"""Enable replication on all secondaries."""
for h, p in client.secondaries:
secondary = await self.async_single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
class AsyncUnitTest(AsyncPyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""

View File

@ -23,16 +23,14 @@ from urllib.parse import quote_plus
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, SkipTest, async_client_context, unittest
from test.utils import (
AllowListEventListener,
async_rs_or_single_client,
async_rs_or_single_client_noauth,
async_single_client,
async_single_client_noauth,
delay,
ignore_deprecations,
from test.asynchronous import (
AsyncIntegrationTest,
AsyncPyMongoTestCase,
SkipTest,
async_client_context,
unittest,
)
from test.utils import AllowListEventListener, delay, ignore_deprecations
from pymongo import AsyncMongoClient, monitoring
from pymongo.asynchronous.auth import HAVE_KERBEROS
@ -81,7 +79,7 @@ class AutoAuthenticateThread(threading.Thread):
self.success = True
class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
class TestGSSAPI(AsyncPyMongoTestCase):
mech_properties: str
service_realm_required: bool
@ -138,7 +136,7 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
if not self.service_realm_required:
# Without authMechanismProperties.
client = AsyncMongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -149,11 +147,11 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
await client[GSSAPI_DB].collection.find_one()
# Log in using URI, without authMechanismProperties.
client = AsyncMongoClient(uri)
client = self.simple_client(uri)
await client[GSSAPI_DB].collection.find_one()
# Authenticate with authMechanismProperties.
client = AsyncMongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -166,14 +164,14 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
# Log in using URI, with authMechanismProperties.
mech_uri = uri + f"&authMechanismProperties={self.mech_properties}"
client = AsyncMongoClient(mech_uri)
client = self.simple_client(mech_uri)
await client[GSSAPI_DB].collection.find_one()
set_name = async_client_context.replica_set_name
if set_name:
if not self.service_realm_required:
# Without authMechanismProperties
client = AsyncMongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -185,11 +183,11 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
await client[GSSAPI_DB].list_collection_names()
uri = uri + f"&replicaSet={set_name!s}"
client = AsyncMongoClient(uri)
client = self.simple_client(uri)
await client[GSSAPI_DB].list_collection_names()
# With authMechanismProperties
client = AsyncMongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -202,13 +200,13 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
await client[GSSAPI_DB].list_collection_names()
mech_uri = mech_uri + f"&replicaSet={set_name!s}"
client = AsyncMongoClient(mech_uri)
client = self.simple_client(mech_uri)
await client[GSSAPI_DB].list_collection_names()
@ignore_deprecations
@async_client_context.require_sync
async def test_gssapi_threaded(self):
client = AsyncMongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -244,7 +242,7 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
set_name = async_client_context.replica_set_name
if set_name:
client = AsyncMongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -267,14 +265,14 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
self.assertTrue(thread.success)
class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
class TestSASLPlain(AsyncPyMongoTestCase):
@classmethod
def setUpClass(cls):
if not SASL_HOST or not SASL_USER or not SASL_PASS:
raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL")
async def test_sasl_plain(self):
client = AsyncMongoClient(
client = self.simple_client(
SASL_HOST,
SASL_PORT,
username=SASL_USER,
@ -293,12 +291,12 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
SASL_PORT,
SASL_DB,
)
client = AsyncMongoClient(uri)
client = self.simple_client(uri)
await client.ldap.test.find_one()
set_name = async_client_context.replica_set_name
if set_name:
client = AsyncMongoClient(
client = self.simple_client(
SASL_HOST,
SASL_PORT,
replicaSet=set_name,
@ -317,7 +315,7 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
SASL_DB,
str(set_name),
)
client = AsyncMongoClient(uri)
client = self.simple_client(uri)
await client.ldap.test.find_one()
async def test_sasl_plain_bad_credentials(self):
@ -331,8 +329,8 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
)
return uri
bad_user = AsyncMongoClient(auth_string("not-user", SASL_PASS))
bad_pwd = AsyncMongoClient(auth_string(SASL_USER, "not-pwd"))
bad_user = self.simple_client(auth_string("not-user", SASL_PASS))
bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd"))
# OperationFailure raised upon connecting.
with self.assertRaises(OperationFailure):
await bad_user.admin.command("ping")
@ -356,7 +354,7 @@ class TestSCRAMSHA1(AsyncIntegrationTest):
async def test_scram_sha1(self):
host, port = await async_client_context.host, await async_client_context.port
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
"mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port)
)
await client.pymongo_test.command("dbstats")
@ -367,7 +365,7 @@ class TestSCRAMSHA1(AsyncIntegrationTest):
"@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1"
"&replicaSet=%s" % (host, port, async_client_context.replica_set_name)
)
client = await async_single_client_noauth(uri)
client = await self.async_single_client_noauth(uri)
await client.pymongo_test.command("dbstats")
db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY)
await db.command("dbstats")
@ -395,7 +393,7 @@ class TestSCRAM(AsyncIntegrationTest):
"testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
)
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram", event_listeners=[listener]
)
await client.testscram.command("dbstats")
@ -432,38 +430,38 @@ class TestSCRAM(AsyncIntegrationTest):
)
# Step 2: verify auth success cases
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha1", password="pwd", authSource="testscram"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
await client.testscram.command("dbstats")
# Step 2: SCRAM-SHA-1 and SCRAM-SHA-256
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
await client.testscram.command("dbstats")
self.listener.reset()
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="both", password="pwd", authSource="testscram", event_listeners=[self.listener]
)
await client.testscram.command("dbstats")
@ -476,19 +474,19 @@ class TestSCRAM(AsyncIntegrationTest):
self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256")
# Step 3: verify auth failure conditions
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
with self.assertRaises(OperationFailure):
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
)
with self.assertRaises(OperationFailure):
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="not-a-user", password="pwd", authSource="testscram"
)
with self.assertRaises(OperationFailure):
@ -501,7 +499,7 @@ class TestSCRAM(AsyncIntegrationTest):
port,
async_client_context.replica_set_name,
)
client = await async_single_client_noauth(uri)
client = await self.async_single_client_noauth(uri)
await client.testscram.command("dbstats")
db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY)
await db.command("dbstats")
@ -521,12 +519,12 @@ class TestSCRAM(AsyncIntegrationTest):
"testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
)
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="\u2168", password="\u2163", authSource="testscram"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="\u2168",
password="\u2163",
authSource="testscram",
@ -534,17 +532,17 @@ class TestSCRAM(AsyncIntegrationTest):
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="\u2168", password="IV", authSource="testscram"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="IX", password="I\u00ADX", authSource="testscram"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="IX",
password="I\u00ADX",
authSource="testscram",
@ -552,31 +550,31 @@ class TestSCRAM(AsyncIntegrationTest):
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
"mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port)
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
"mongodb://\u2168:IV@%s:%d/testscram" % (host, port)
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
"mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)
)
await client.testscram.command("dbstats")
client = await async_rs_or_single_client_noauth(
client = await self.async_rs_or_single_client_noauth(
"mongodb://IX:IX@%s:%d/testscram" % (host, port)
)
await client.testscram.command("dbstats")
async def test_cache(self):
client = await async_single_client()
client = await self.async_single_client()
credentials = client.options.pool_options._credentials
cache = credentials.cache
self.assertIsNotNone(cache)
@ -601,8 +599,7 @@ class TestSCRAM(AsyncIntegrationTest):
await coll.insert_one({"_id": 1})
# The first thread to call find() will authenticate
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client()
coll = client.db.test
threads = []
for _ in range(4):
@ -631,7 +628,9 @@ class TestAuthURIOptions(AsyncIntegrationTest):
async def test_uri_options(self):
# Test default to admin
host, port = await async_client_context.host, await async_client_context.port
client = await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))
client = await self.async_rs_or_single_client_noauth(
"mongodb://admin:pass@%s:%d" % (host, port)
)
self.assertTrue(await client.admin.command("dbstats"))
if async_client_context.is_rs:
@ -640,14 +639,14 @@ class TestAuthURIOptions(AsyncIntegrationTest):
port,
async_client_context.replica_set_name,
)
client = await async_single_client_noauth(uri)
client = await self.async_single_client_noauth(uri)
self.assertTrue(await client.admin.command("dbstats"))
db = client.get_database("admin", read_preference=ReadPreference.SECONDARY)
self.assertTrue(await db.command("dbstats"))
# Test explicit database
uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port)
client = await async_rs_or_single_client_noauth(uri)
client = await self.async_rs_or_single_client_noauth(uri)
with self.assertRaises(OperationFailure):
await client.admin.command("dbstats")
self.assertTrue(await client.pymongo_test.command("dbstats"))
@ -658,7 +657,7 @@ class TestAuthURIOptions(AsyncIntegrationTest):
port,
async_client_context.replica_set_name,
)
client = await async_single_client_noauth(uri)
client = await self.async_single_client_noauth(uri)
with self.assertRaises(OperationFailure):
await client.admin.command("dbstats")
self.assertTrue(await client.pymongo_test.command("dbstats"))
@ -667,7 +666,7 @@ class TestAuthURIOptions(AsyncIntegrationTest):
# Test authSource
uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port)
client = await async_rs_or_single_client_noauth(uri)
client = await self.async_rs_or_single_client_noauth(uri)
with self.assertRaises(OperationFailure):
await client.pymongo_test2.command("dbstats")
self.assertTrue(await client.pymongo_test.command("dbstats"))
@ -677,7 +676,7 @@ class TestAuthURIOptions(AsyncIntegrationTest):
"mongodb://user:pass@%s:%d/pymongo_test2?replicaSet="
"%s;authSource=pymongo_test" % (host, port, async_client_context.replica_set_name)
)
client = await async_single_client_noauth(uri)
client = await self.async_single_client_noauth(uri)
with self.assertRaises(OperationFailure):
await client.pymongo_test2.command("dbstats")
self.assertTrue(await client.pymongo_test.command("dbstats"))

View File

@ -20,6 +20,7 @@ import json
import os
import sys
import warnings
from test.asynchronous import AsyncPyMongoTestCase
sys.path[0:0] = [""]
@ -34,7 +35,7 @@ _IS_SYNC = False
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")
class TestAuthSpec(unittest.IsolatedAsyncioTestCase):
class TestAuthSpec(AsyncPyMongoTestCase):
pass
@ -54,7 +55,7 @@ def create_test(test_case):
warnings.simplefilter("default")
self.assertRaises(Exception, AsyncMongoClient, uri, connect=False)
else:
client = AsyncMongoClient(uri, connect=False)
client = self.simple_client(uri, connect=False)
credentials = client.options.pool_options._credentials
if credential is None:
self.assertIsNone(credentials)

View File

@ -24,23 +24,14 @@ from pymongo.asynchronous.mongo_client import AsyncMongoClient
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest
from test.utils import (
async_rs_or_single_client_noauth,
async_single_client,
async_wait_until,
)
from test.utils import async_wait_until
from bson.binary import Binary, UuidRepresentation
from bson.codec_options import CodecOptions
from bson.objectid import ObjectId
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.common import partition_node
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure,
)
from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import *
from pymongo.write_concern import WriteConcern
@ -915,7 +906,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase):
async def test_readonly(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
cli = await async_rs_or_single_client_noauth(
cli = await self.async_rs_or_single_client_noauth(
username="readonly", password="pw", authSource="pymongo_test"
)
coll = cli.pymongo_test.test
@ -926,7 +917,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase):
async def test_no_remove(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
cli = await async_rs_or_single_client_noauth(
cli = await self.async_rs_or_single_client_noauth(
username="noremove", password="pw", authSource="pymongo_test"
)
coll = cli.pymongo_test.test
@ -954,7 +945,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
if cls.w is not None and cls.w > 1:
for member in (await async_client_context.hello)["hosts"]:
if member != (await async_client_context.hello)["primary"]:
cls.secondary = await async_single_client(*partition_node(member))
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
break
@classmethod

View File

@ -28,12 +28,17 @@ from typing import no_type_check
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, Version, async_client_context, unittest
from test.asynchronous import (
AsyncIntegrationTest,
AsyncPyMongoTestCase,
Version,
async_client_context,
unittest,
)
from test.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
EventListener,
async_rs_or_single_client,
async_wait_until,
)
@ -69,8 +74,7 @@ class TestAsyncChangeStreamBase(AsyncIntegrationTest):
async def client_with_listener(self, *commands):
"""Return a client with a AllowListEventListener."""
listener = AllowListEventListener(*commands)
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
return client, listener
def watched_collection(self, *args, **kwargs):
@ -176,7 +180,7 @@ class APITestsMixin:
@no_type_check
async def test_try_next_runs_one_getmore(self):
listener = EventListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
# Connect to the cluster.
await client.admin.command("ping")
listener.reset()
@ -234,7 +238,7 @@ class APITestsMixin:
@no_type_check
async def test_batch_size_is_honored(self):
listener = EventListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
# Connect to the cluster.
await client.admin.command("ping")
listener.reset()
@ -481,7 +485,9 @@ class ProseSpecTestsMixin:
@no_type_check
async def _client_with_listener(self, *commands):
listener = AllowListEventListener(*commands)
client = await async_rs_or_single_client(event_listeners=[listener])
client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client(
event_listeners=[listener]
)
self.addAsyncCleanup(client.close)
return client, listener
@ -1131,7 +1137,7 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
async def _setup_class(cls):
await super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = await async_rs_or_single_client(event_listeners=[cls.listener])
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,6 @@ from test.asynchronous import (
)
from test.utils import (
OvertCommandListener,
async_rs_or_single_client,
)
from unittest.mock import patch
@ -39,7 +38,6 @@ from pymongo.errors import (
InvalidOperation,
NetworkTimeout,
)
from pymongo.monitoring import *
from pymongo.operations import *
from pymongo.write_concern import WriteConcern
@ -97,8 +95,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_batch_splits_if_num_operations_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
models = []
for _ in range(self.max_write_batch_size + 1):
@ -123,8 +120,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_batch_splits_if_ops_payload_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
models = []
num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1)
@ -157,11 +153,10 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_failCommand_fail_point
async def test_collects_write_concern_errors_across_batches(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryWrites=False,
)
self.addAsyncCleanup(client.close)
fail_command = {
"configureFailPoint": "failCommand",
@ -200,8 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_collects_write_errors_across_batches_unordered(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -231,8 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_collects_write_errors_across_batches_ordered(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -262,8 +255,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_handles_cursor_requiring_getMore(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -304,8 +296,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_standalone
async def test_handles_cursor_requiring_getMore_within_transaction(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -348,8 +339,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_failCommand_fail_point
async def test_handles_getMore_error(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -403,8 +393,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_returns_error_if_unacknowledged_too_large_insert(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
b_repeated = "b" * self.max_bson_object_size
@ -460,8 +449,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_no_batch_splits_if_new_namespace_is_not_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
num_models, models = await self._setup_namespace_test_models()
models.append(
@ -492,8 +480,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_no_serverless
async def test_batch_splits_if_new_namespace_is_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
num_models, models = await self._setup_namespace_test_models()
c_repeated = "c" * 200
@ -530,8 +517,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_version_min(8, 0, 0, -24)
@async_client_context.require_no_serverless
async def test_returns_error_if_no_writes_can_be_added_to_ops(self):
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client()
# Document too large.
b_repeated = "b" * self.max_message_size_bytes
@ -554,8 +540,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
key_vault_namespace="db.coll",
kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}},
)
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(InvalidOperation) as context:
@ -580,7 +565,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
async def test_timeout_in_multi_batch_bulk_write(self):
_OVERHEAD = 500
internal_client = await async_rs_or_single_client(timeoutMS=None)
internal_client = await self.async_rs_or_single_client(timeoutMS=None)
self.addAsyncCleanup(internal_client.close)
collection = internal_client.db["coll"]
@ -605,14 +590,13 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
)
listener = OvertCommandListener()
client = await async_rs_or_single_client(
client = await self.async_rs_or_single_client(
event_listeners=[listener],
readConcernLevel="majority",
readPreference="primary",
timeoutMS=2000,
w="majority",
)
self.addAsyncCleanup(client.close)
await client.admin.command("ping") # Init the client first.
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)

View File

@ -30,6 +30,7 @@ sys.path[0:0] = [""]
from test import unittest
from test.asynchronous import ( # TODO: fix sync imports in PYTHON-4528
AsyncIntegrationTest,
AsyncUnitTest,
async_client_context,
)
from test.utils import (
@ -37,8 +38,6 @@ from test.utils import (
EventListener,
async_get_pool,
async_is_mongos,
async_rs_or_single_client,
async_single_client,
async_wait_until,
wait_until,
)
@ -82,14 +81,20 @@ from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class TestCollectionNoConnect(unittest.TestCase):
class TestCollectionNoConnect(AsyncUnitTest):
"""Test Collection features on a client that does not connect."""
db: AsyncDatabase
client: AsyncMongoClient
@classmethod
def setUpClass(cls):
cls.db = AsyncMongoClient(connect=False).pymongo_test
async def _setup_class(cls):
cls.client = AsyncMongoClient(connect=False)
cls.db = cls.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
def test_collection(self):
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
@ -1819,8 +1824,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
# Insert enough documents to require more than one batch
await self.db.test.insert_many([{"i": i} for i in range(150)])
client = await async_rs_or_single_client(maxPoolSize=1)
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(maxPoolSize=1)
pool = await async_get_pool(client)
# Make sure the socket is returned after exhaustion.
@ -2100,7 +2104,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
async def test_find_one_and_write_concern(self):
listener = EventListener()
db = (await async_single_client(event_listeners=[listener]))[self.db.name]
db = (await self.async_single_client(event_listeners=[listener]))[self.db.name]
# non-default WriteConcern.
c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0))
# default WriteConcern.

View File

@ -34,7 +34,7 @@ from test.utils import (
AllowListEventListener,
EventListener,
OvertCommandListener,
async_rs_or_single_client,
delay,
ignore_deprecations,
wait_until,
)
@ -45,7 +45,7 @@ from pymongo import ASCENDING, DESCENDING
from pymongo.asynchronous.cursor import AsyncCursor, CursorType
from pymongo.asynchronous.helpers import anext
from pymongo.collation import Collation
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure, PyMongoError
from pymongo.operations import _IndexList
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
@ -232,7 +232,7 @@ class TestCursor(AsyncIntegrationTest):
self.assertEqual(90, cursor._max_await_time_ms)
listener = AllowListEventListener("find", "getMore")
coll = (await async_rs_or_single_client(event_listeners=[listener]))[
coll = (await self.async_rs_or_single_client(event_listeners=[listener]))[
self.db.name
].pymongo_test
@ -353,8 +353,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_explain_with_read_concern(self):
# Do not add readConcern level to explain.
listener = AllowListEventListener("explain")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local"))
self.assertTrue(await coll.find().explain())
started = listener.started_events
@ -1261,8 +1260,7 @@ class TestCursor(AsyncIntegrationTest):
await self.client._process_periodic_tasks()
listener = AllowListEventListener("killCursors")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
coll = client[self.db.name].test_close_kills_cursors
# Add some test data.
@ -1300,8 +1298,7 @@ class TestCursor(AsyncIntegrationTest):
@async_client_context.require_failCommand_appName
async def test_timeout_kills_cursor_asynchronously(self):
listener = AllowListEventListener("killCursors")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
coll = client[self.db.name].test_timeout_kills_cursor
# Add some test data.
@ -1358,8 +1355,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_getMore_does_not_send_readPreference(self):
listener = AllowListEventListener("find", "getMore")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
# We never send primary read preference so override the default.
coll = client[self.db.name].get_collection(
"test", read_preference=ReadPreference.PRIMARY_PREFERRED
@ -1415,6 +1411,18 @@ class TestCursor(AsyncIntegrationTest):
docs = await c.to_list(3)
self.assertEqual(len(docs), 2)
async def test_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])
cursor = coll.find({"$where": delay(1)})
with self.assertRaises(PyMongoError) as ctx:
await cursor.to_list()
self.assertTrue(ctx.exception.timeout)
@async_client_context.require_change_streams
async def test_command_cursor_to_list(self):
# Set maxAwaitTimeMS=1 to speed up the test.
@ -1444,6 +1452,25 @@ class TestCursor(AsyncIntegrationTest):
result = await db.test.aggregate([pipeline])
self.assertEqual(len(await result.to_list(1)), 1)
@async_client_context.require_failCommand_blockConnection
async def test_command_cursor_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 5},
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 1000},
}
cursor = await coll.aggregate([], batchSize=1)
async with self.fail_point(fail_command):
with self.assertRaises(PyMongoError) as ctx:
await cursor.to_list()
self.assertTrue(ctx.exception.timeout)
class TestRawBatchCursor(AsyncIntegrationTest):
async def test_find_raw(self):
@ -1463,7 +1490,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
await c.insert_many(docs)
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
async with client.start_session() as session:
async with await session.start_transaction():
batches = await (
@ -1493,7 +1520,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
await c.insert_many(docs)
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
async with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
):
@ -1514,7 +1541,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
await c.insert_many(docs)
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
db = client[self.db.name]
async with client.start_session(snapshot=True) as session:
await db.test.distinct("x", {}, session=session)
@ -1577,7 +1604,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
async def test_monitoring(self):
listener = EventListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
c = client.pymongo_test.test
await c.drop()
await c.insert_many([{"_id": i} for i in range(10)])
@ -1643,7 +1670,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
await c.insert_many(docs)
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
async with client.start_session() as session:
async with await session.start_transaction():
batches = await (
@ -1674,7 +1701,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
await c.insert_many(docs)
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
async with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}}
):
@ -1698,7 +1725,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
await c.insert_many(docs)
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
db = client[self.db.name]
async with client.start_session(snapshot=True) as session:
await db.test.distinct("x", {}, session=session)
@ -1744,7 +1771,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
async def test_monitoring(self):
listener = EventListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
c = client.pymongo_test.test
await c.drop()
await c.insert_many([{"_id": i} for i in range(10)])
@ -1788,8 +1815,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
@async_client_context.require_no_mongos
async def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
c = client.pymongo_test.test
await c.delete_many({})
await c.insert_many([{"_id": i} for i in range(3)])

View File

@ -29,7 +29,6 @@ from test.test_custom_types import DECIMAL_CODECOPTS
from test.utils import (
IMPOSSIBLE_WRITE_CONCERN,
OvertCommandListener,
async_rs_or_single_client,
async_wait_until,
)
@ -208,7 +207,7 @@ class TestDatabase(AsyncIntegrationTest):
async def test_list_collection_names_filter(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
db = client[self.db.name]
await db.capped.drop()
await db.create_collection("capped", capped=True, size=4096)
@ -235,8 +234,7 @@ class TestDatabase(AsyncIntegrationTest):
async def test_check_exists(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(event_listeners=[listener])
db = client[self.db.name]
await db.drop_collection("unique")
await db.create_collection("unique", check_exists=True)
@ -326,7 +324,7 @@ class TestDatabase(AsyncIntegrationTest):
await self.client.drop_database("pymongo_test")
async def test_list_collection_names_single_socket(self):
client = await async_rs_or_single_client(maxPoolSize=1)
client = await self.async_rs_or_single_client(maxPoolSize=1)
await client.drop_database("test_collection_names_single_socket")
db = client.test_collection_names_single_socket
for i in range(200):

View File

@ -31,7 +31,7 @@ import warnings
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
from test.asynchronous.test_bulk import AsyncBulkTestBase
from threading import Thread
from typing import Any, Dict, Mapping
from typing import Any, Dict, Mapping, Optional
import pytest
@ -44,6 +44,8 @@ sys.path[0:0] = [""]
from test import (
unittest,
)
from test.asynchronous.test_bulk import AsyncBulkTestBase
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
from test.helpers import (
AWS_CREDS,
AZURE_CREDS,
@ -59,12 +61,10 @@ from test.utils import (
OvertCommandListener,
SpecTestCreator,
TopologyEventListener,
async_rs_or_single_client,
async_wait_until,
camel_to_snake_args,
is_greenthread_patched,
)
from test.utils_spec_runner import SpecRunner
from bson import DatetimeMS, Decimal128, encode, json_util
from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation
@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
@unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed")
async def test_crypt_shared(self):
# Test that we can pick up crypt_shared lib automatically
client = AsyncMongoClient(
self.simple_client(
auto_encryption_opts=AutoEncryptionOpts(
KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True
),
connect=False,
)
self.addAsyncCleanup(client.aclose)
@unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed")
def test_init_requires_pymongocrypt(self):
@ -196,19 +195,16 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
class TestClientOptions(AsyncPyMongoTestCase):
async def test_default(self):
client = AsyncMongoClient(connect=False)
self.addAsyncCleanup(client.aclose)
client = self.simple_client(connect=False)
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
client = AsyncMongoClient(auto_encryption_opts=None, connect=False)
self.addAsyncCleanup(client.aclose)
client = self.simple_client(auto_encryption_opts=None, connect=False)
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
async def test_kwargs(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = AsyncMongoClient(auto_encryption_opts=opts, connect=False)
self.addAsyncCleanup(client.aclose)
client = self.simple_client(auto_encryption_opts=opts, connect=False)
self.assertEqual(get_client_opts(client).auto_encryption_opts, opts)
@ -229,6 +225,34 @@ class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
self.assertIsInstance(val, Binary)
self.assertEqual(val.subtype, UUID_SUBTYPE)
def create_client_encryption(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: AsyncMongoClient,
codec_options: CodecOptions,
kms_tls_options: Optional[Mapping[str, Any]] = None,
):
client_encryption = AsyncClientEncryption(
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
)
self.addAsyncCleanup(client_encryption.close)
return client_encryption
@classmethod
def unmanaged_create_client_encryption(
cls,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: AsyncMongoClient,
codec_options: CodecOptions,
kms_tls_options: Optional[Mapping[str, Any]] = None,
):
client_encryption = AsyncClientEncryption(
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
)
return client_encryption
# Location of JSON test files.
if _IS_SYNC:
@ -260,8 +284,7 @@ def bson_data(*paths):
class TestClientSimple(AsyncEncryptionIntegrationTest):
async def _test_auto_encrypt(self, opts):
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
# Create the encrypted field's data key.
key_vault = await create_key_vault(
@ -342,8 +365,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest):
async def test_use_after_close(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
await client.admin.command("ping")
await client.aclose()
@ -358,10 +380,10 @@ class TestClientSimple(AsyncEncryptionIntegrationTest):
is_greenthread_patched(),
"gevent and eventlet do not support POSIX-style forking.",
)
@async_client_context.require_sync
async def test_fork(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
async def target():
with warnings.catch_warnings():
@ -375,8 +397,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest):
class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
async def test_upsert_uuid_standard_encrypt(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
encrypted_coll = client.pymongo_test.test
@ -416,8 +437,7 @@ class TestClientMaxWireVersion(AsyncIntegrationTest):
@async_client_context.require_version_max(4, 0, 99)
async def test_raise_max_wire_version_error(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
msg = "Auto-encryption requires a minimum MongoDB version of 4.2"
with self.assertRaisesRegex(ConfigurationError, msg):
await client.test.test.insert_one({})
@ -430,8 +450,7 @@ class TestClientMaxWireVersion(AsyncIntegrationTest):
async def test_raise_unsupported_error(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
msg = "find_raw_batches does not support auto encryption"
with self.assertRaisesRegex(InvalidOperation, msg):
await client.test.test.find_raw_batches({})
@ -450,10 +469,9 @@ class TestClientMaxWireVersion(AsyncIntegrationTest):
class TestExplicitSimple(AsyncEncryptionIntegrationTest):
async def test_encrypt_decrypt(self):
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
)
self.addAsyncCleanup(client_encryption.close)
# Use standard UUID representation.
key_vault = async_client_context.client.keyvault.get_collection(
"datakeys", codec_options=OPTS
@ -495,10 +513,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
self.assertEqual(decrypted_ssn, doc["ssn"])
async def test_validation(self):
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
)
self.addAsyncCleanup(client_encryption.close)
msg = "value to decrypt must be a bson.binary.Binary with subtype 6"
with self.assertRaisesRegex(TypeError, msg):
@ -512,10 +529,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
await client_encryption.encrypt("str", algo, key_id=Binary(b"123"))
async def test_bson_errors(self):
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
)
self.addAsyncCleanup(client_encryption.close)
# Attempt to encrypt an unencodable object.
unencodable_value = object()
@ -528,7 +544,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
async def test_codec_options(self):
with self.assertRaisesRegex(TypeError, "codec_options must be"):
AsyncClientEncryption(
self.create_client_encryption(
KMS_PROVIDERS,
"keyvault.datakeys",
async_client_context.client,
@ -536,10 +552,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
)
opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY)
client_encryption_legacy = AsyncClientEncryption(
client_encryption_legacy = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts
)
self.addAsyncCleanup(client_encryption_legacy.close)
# Create the encrypted field's data key.
key_id = await client_encryption_legacy.create_data_key("local")
@ -554,10 +569,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
# Encrypt the same UUID with STANDARD codec options.
opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts
)
self.addAsyncCleanup(client_encryption.close)
encrypted_standard = await client_encryption.encrypt(
value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
)
@ -573,7 +587,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
self.assertNotEqual(await client_encryption.decrypt(encrypted_legacy), value)
async def test_close(self):
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
)
await client_encryption.close()
@ -589,7 +603,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
await client_encryption.decrypt(Binary(b"", 6))
async def test_with_statement(self):
async with AsyncClientEncryption(
async with self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
) as client_encryption:
pass
@ -613,7 +627,7 @@ KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PE
if _IS_SYNC:
# TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
class TestSpec(SpecRunner):
class TestSpec(AsyncSpecRunner):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def setUpClass(cls):
@ -811,7 +825,7 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = await async_rs_or_single_client(event_listeners=[cls.listener])
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
await cls.client.db.coll.drop()
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
@ -833,10 +847,10 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
)
cls.client_encrypted = await async_rs_or_single_client(
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
cls.client_encryption = AsyncClientEncryption(
cls.client_encryption = cls.unmanaged_create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
)
@ -923,10 +937,9 @@ class TestExternalKeyVault(AsyncEncryptionIntegrationTest):
# Configure the encrypted field via the local schema_map option.
schemas = {"db.coll": json_data("external", "external-schema.json")}
if with_external_key_vault:
key_vault_client = await async_rs_or_single_client(
key_vault_client = await self.async_rs_or_single_client(
username="fake-user", password="fake-pwd"
)
self.addAsyncCleanup(key_vault_client.close)
else:
key_vault_client = async_client_context.client
opts = AutoEncryptionOpts(
@ -936,15 +949,13 @@ class TestExternalKeyVault(AsyncEncryptionIntegrationTest):
key_vault_client=key_vault_client,
)
client_encrypted = await async_rs_or_single_client(
client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
self.addAsyncCleanup(client_encrypted.close)
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS
)
self.addAsyncCleanup(client_encryption.close)
if with_external_key_vault:
# Authentication error.
@ -990,10 +1001,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
self.addAsyncCleanup(self.client.db.view.drop)
opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys")
client_encrypted = await async_rs_or_single_client(
client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
self.addAsyncCleanup(client_encrypted.aclose)
with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"):
await client_encrypted.db.view.insert_one({})
@ -1050,17 +1060,15 @@ class TestCorpus(AsyncEncryptionIntegrationTest):
)
self.addAsyncCleanup(vault.drop)
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client_encrypted.close)
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
self.kms_providers(),
"keyvault.datakeys",
async_client_context.client,
OPTS,
kms_tls_options=KMS_TLS_OPTS,
)
self.addAsyncCleanup(client_encryption.close)
corpus = self.fix_up_curpus(json_data("corpus", "corpus.json"))
corpus_copied: SON = SON()
@ -1203,7 +1211,7 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener()
cls.client_encrypted = await async_rs_or_single_client(
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener]
)
cls.coll_encrypted = cls.client_encrypted.db.coll
@ -1291,7 +1299,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
"gcp": GCP_CREDS,
"kmip": KMIP_CREDS,
}
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers=kms_providers,
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
@ -1303,7 +1311,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443"
kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443"
kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698"
self.client_encryption_invalid = AsyncClientEncryption(
self.client_encryption_invalid = self.create_client_encryption(
kms_providers=kms_providers_invalid,
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
@ -1484,7 +1492,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
await self.client_encryption.create_data_key("kmip", key)
class AzureGCPEncryptionTestMixin:
class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
DEK = None
KMS_PROVIDER_MAP = None
KEYVAULT_DB = "keyvault"
@ -1496,7 +1504,7 @@ class AzureGCPEncryptionTestMixin:
await create_key_vault(keyvault, self.DEK)
async def _test_explicit(self, expectation):
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
async_client_context.client,
@ -1525,7 +1533,7 @@ class AzureGCPEncryptionTestMixin:
)
insert_listener = AllowListEventListener("insert")
client = await async_rs_or_single_client(
client = await self.async_rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addAsyncCleanup(client.aclose)
@ -1604,19 +1612,17 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
self.client_test = await async_rs_or_single_client(
self.client_test = await self.async_rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
self.addAsyncCleanup(self.client_test.aclose)
self.client_keyvault_listener = OvertCommandListener()
self.client_keyvault = await async_rs_or_single_client(
self.client_keyvault = await self.async_rs_or_single_client(
maxPoolSize=1,
readConcernLevel="majority",
w="majority",
event_listeners=[self.client_keyvault_listener],
)
self.addAsyncCleanup(self.client_keyvault.aclose)
await self.client_test.keyvault.datakeys.drop()
await self.client_test.db.coll.drop()
@ -1629,7 +1635,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
codec_options=OPTS,
)
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=self.client_test,
@ -1645,7 +1651,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
async def _run_test(self, max_pool_size, auto_encryption_opts):
client_encrypted = await async_rs_or_single_client(
client_encrypted = await self.async_rs_or_single_client(
readConcernLevel="majority",
w="majority",
maxPoolSize=max_pool_size,
@ -1663,8 +1669,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
result = await client_encrypted.db.coll.find_one({"_id": 0})
self.assertEqual(result, {"_id": 0, "encrypted": "string0"})
self.addAsyncCleanup(client_encrypted.close)
async def test_case_1(self):
await self._run_test(
max_pool_size=1,
@ -1840,7 +1844,7 @@ class TestDecryptProse(AsyncEncryptionIntegrationTest):
await create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
)
keyID = await self.client_encryption.create_data_key("local")
@ -1855,10 +1859,9 @@ class TestDecryptProse(AsyncEncryptionIntegrationTest):
key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map
)
self.listener = AllowListEventListener("aggregate")
self.encrypted_client = await async_rs_or_single_client(
self.encrypted_client = await self.async_rs_or_single_client(
auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener]
)
self.addAsyncCleanup(self.encrypted_client.close)
async def test_01_command_error(self):
async with self.fail_point(
@ -1935,8 +1938,7 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
"--port=27027",
],
)
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client_encrypted.close)
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
with self.assertRaisesRegex(EncryptionError, "Timeout"):
await client_encrypted.db.coll.insert_one({"encrypted": "test"})
@ -1950,11 +1952,10 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
"--port=27027",
],
)
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client_encrypted.aclose)
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
await client_encrypted.db.coll.insert_one({"unencrypted": "test"})
# Validate that mongocryptd was not spawned:
mongocryptd_client = AsyncMongoClient(
mongocryptd_client = self.simple_client(
"mongodb://localhost:27027/?serverSelectionTimeoutMS=500"
)
with self.assertRaises(ServerSelectionTimeoutError):
@ -1978,15 +1979,13 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
],
crypt_shared_lib_required=True,
)
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client_encrypted.aclose)
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
await client_encrypted.db.coll.drop()
await client_encrypted.db.coll.insert_one({"encrypted": "test"})
self.assertEncrypted((await async_client_context.client.db.coll.find_one({}))["encrypted"])
no_mongocryptd_client = AsyncMongoClient(
no_mongocryptd_client = self.simple_client(
host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000"
)
self.addAsyncCleanup(no_mongocryptd_client.aclose)
with self.assertRaises(ServerSelectionTimeoutError):
await no_mongocryptd_client.db.command("ping")
@ -2020,8 +2019,7 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
mongocryptd_uri="mongodb://localhost:47021",
crypt_shared_lib_required=False,
)
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client_encrypted.aclose)
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
await client_encrypted.db.coll.drop()
await client_encrypted.db.coll.insert_one({"encrypted": "test"})
server.shutdown()
@ -2035,10 +2033,9 @@ class TestKmsTLSProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.patch_system_certs(CA_PEM)
self.client_encrypted = AsyncClientEncryption(
self.client_encrypted = self.create_client_encryption(
{"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS
)
self.addAsyncCleanup(self.client_encrypted.close)
async def test_invalid_kms_certificate_expired(self):
key = {
@ -2083,36 +2080,32 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
"gcp": {"tlsCAFile": CA_PEM},
"kmip": {"tlsCAFile": CA_PEM},
}
self.client_encryption_no_client_cert = AsyncClientEncryption(
self.client_encryption_no_client_cert = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
)
self.addAsyncCleanup(self.client_encryption_no_client_cert.close)
# 2, same providers as above but with tlsCertificateKeyFile.
kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only)
for p in kms_tls_opts:
kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM
self.client_encryption_with_tls = AsyncClientEncryption(
self.client_encryption_with_tls = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
)
self.addAsyncCleanup(self.client_encryption_with_tls.close)
# 3, update endpoints to expired host.
providers: dict = copy.deepcopy(providers)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000"
providers["gcp"]["endpoint"] = "127.0.0.1:9000"
providers["kmip"]["endpoint"] = "127.0.0.1:9000"
self.client_encryption_expired = AsyncClientEncryption(
self.client_encryption_expired = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
)
self.addAsyncCleanup(self.client_encryption_expired.close)
# 3, update endpoints to invalid host.
providers: dict = copy.deepcopy(providers)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001"
providers["gcp"]["endpoint"] = "127.0.0.1:9001"
providers["kmip"]["endpoint"] = "127.0.0.1:9001"
self.client_encryption_invalid_hostname = AsyncClientEncryption(
self.client_encryption_invalid_hostname = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
)
self.addAsyncCleanup(self.client_encryption_invalid_hostname.close)
# Errors when client has no cert, some examples:
# [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623)
self.cert_error = (
@ -2150,7 +2143,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
"gcp:with_tls": with_cert,
"kmip:with_tls": with_cert,
}
self.client_encryption_with_names = AsyncClientEncryption(
self.client_encryption_with_names = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4
)
@ -2232,10 +2225,9 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self):
providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}
options = {"aws": {"tlsDisableOCSPEndpointCheck": True}}
encryption = AsyncClientEncryption(
encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
)
self.addAsyncCleanup(encryption.close)
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
if not hasattr(ctx, "check_ocsp_endpoint"):
raise self.skipTest("OCSP not enabled")
@ -2285,7 +2277,7 @@ class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
self.client = async_client_context.client
await create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
)
self.def_key_id = await self.client_encryption.create_data_key(
@ -2327,17 +2319,15 @@ class TestExplicitQueryableEncryption(AsyncEncryptionIntegrationTest):
key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addCleanup(key_vault.drop)
self.key_vault_client = self.client
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
)
self.addAsyncCleanup(self.client_encryption.close)
opts = AutoEncryptionOpts(
{"local": {"key": LOCAL_MASTER_KEY}},
key_vault.full_name,
bypass_query_analysis=True,
)
self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(self.encrypted_client.aclose)
self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
async def test_01_insert_encrypted_indexed_and_find(self):
val = "encrypted indexed value"
@ -2464,14 +2454,13 @@ class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest):
await self.client.keyvault.drop_collection("datakeys")
# Step 2. Create a ``AsyncClientEncryption`` object named ``client_encryption1``
client_encryption1 = AsyncClientEncryption(
client_encryption1 = self.create_client_encryption(
key_vault_client=self.client,
key_vault_namespace="keyvault.datakeys",
kms_providers=ALL_KMS_PROVIDERS,
kms_tls_options=KMS_TLS_OPTS,
codec_options=OPTS,
)
self.addAsyncCleanup(client_encryption1.close)
# Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``.
key_id = await client_encryption1.create_data_key(
@ -2484,16 +2473,14 @@ class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest):
)
# Step 5. Create a ``AsyncClientEncryption`` object named ``client_encryption2``
client2 = await async_rs_or_single_client()
self.addAsyncCleanup(client2.aclose)
client_encryption2 = AsyncClientEncryption(
client2 = await self.async_rs_or_single_client()
client_encryption2 = self.create_client_encryption(
key_vault_client=client2,
key_vault_namespace="keyvault.datakeys",
kms_providers=ALL_KMS_PROVIDERS,
kms_tls_options=KMS_TLS_OPTS,
codec_options=OPTS,
)
self.addAsyncCleanup(client_encryption2.close)
# Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``.
rewrap_many_data_key_result = await client_encryption2.rewrap_many_data_key(
@ -2528,7 +2515,7 @@ class TestOnDemandAWSCredentials(AsyncEncryptionIntegrationTest):
@unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set")
async def test_01_failure(self):
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers={"aws": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
@ -2539,7 +2526,7 @@ class TestOnDemandAWSCredentials(AsyncEncryptionIntegrationTest):
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def test_02_success(self):
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers={"aws": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
@ -2559,8 +2546,7 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
# AsyncMongoClient to use in testing that handles auth/tls/etc,
# and cleanup.
async def AsyncMongoClient(**kwargs):
c = await async_rs_or_single_client(**kwargs)
self.addAsyncCleanup(c.aclose)
c = await self.async_rs_or_single_client(**kwargs)
return c
# Drop data from prior test runs.
@ -2571,7 +2557,7 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
# Create two data keys.
key_vault_client = await AsyncMongoClient()
client_encryption = AsyncClientEncryption(
client_encryption = self.create_client_encryption(
kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions()
)
key1_id = await client_encryption.create_data_key("local")
@ -2652,18 +2638,16 @@ class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addCleanup(key_vault.drop)
self.key_vault_client = self.client
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
)
self.addAsyncCleanup(self.client_encryption.close)
opts = AutoEncryptionOpts(
{"local": {"key": LOCAL_MASTER_KEY}},
key_vault.full_name,
bypass_query_analysis=True,
)
self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
self.db = self.encrypted_client.db
self.addAsyncCleanup(self.encrypted_client.aclose)
async def run_expression_find(
self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None
@ -2860,10 +2844,9 @@ class TestRangeQueryDefaultsProse(AsyncEncryptionIntegrationTest):
await super().asyncSetUp()
await self.client.drop_database(self.db)
self.key_vault_client = self.client
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS
)
self.addAsyncCleanup(self.client_encryption.close)
self.key_id = await self.client_encryption.create_data_key("local")
opts = RangeOpts(min=0, max=1000)
self.payload_defaults = await self.client_encryption.encrypt(
@ -2896,13 +2879,12 @@ class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
await self.client.drop_database(self.db)
self.key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addAsyncCleanup(self.key_vault.drop)
self.client_encryption = AsyncClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}},
self.key_vault.full_name,
self.client,
OPTS,
)
self.addAsyncCleanup(self.client_encryption.close)
async def test_01_simple_create(self):
coll, _ = await self.client_encryption.create_encrypted_collection(
@ -3118,10 +3100,9 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self) -> None:
self.listener = OvertCommandListener()
self.mongocryptd_client = AsyncMongoClient(
self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
)
self.addAsyncCleanup(self.mongocryptd_client.aclose)
hello = await self.mongocryptd_client.db.command("hello")
self.assertNotIn("logicalSessionTimeoutMinutes", hello)

View File

@ -33,7 +33,7 @@ from pymongo.asynchronous.database import AsyncDatabase
sys.path[0:0] = [""]
from test.utils import EventListener, async_rs_or_single_client
from test.utils import EventListener
from bson.objectid import ObjectId
from gridfs.asynchronous.grid_file import (
@ -792,7 +792,7 @@ Bye"""
await outfile.readchunk()
async def test_grid_in_lazy_connect(self):
client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10)
fs = client.db.fs
infile = AsyncGridIn(fs, file_id=-1, chunk_size=1)
with self.assertRaises(ServerSelectionTimeoutError):
@ -803,7 +803,7 @@ Bye"""
async def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs)
AsyncGridIn((await self.async_rs_or_single_client(w=0)).pymongo_test.fs)
async def test_survive_cursor_not_found(self):
# By default the find command returns 101 documents in the first batch.
@ -811,7 +811,7 @@ Bye"""
chunk_size = 1024
data = b"d" * (102 * chunk_size)
listener = EventListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
db = client.pymongo_test
async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile:
await infile.write(data)

View File

@ -16,7 +16,6 @@ from __future__ import annotations
import os
from test import unittest
from test.asynchronous import AsyncIntegrationTest
from test.utils import async_single_client
from unittest.mock import patch
from bson import json_util
@ -86,7 +85,7 @@ class TestLogger(AsyncIntegrationTest):
self.assertEqual(last_3_bytes, str_to_repeat)
async def test_logging_without_listeners(self):
c = await async_single_client()
c = await self.async_single_client()
self.assertEqual(len(c._event_listeners.event_listeners()), 0)
with self.assertLogs("pymongo.connection", level="DEBUG") as cm:
await c.db.test.insert_one({"x": "1"})

View File

@ -31,8 +31,6 @@ from test.asynchronous import (
)
from test.utils import (
EventListener,
async_rs_or_single_client,
async_single_client,
async_wait_until,
)
@ -57,7 +55,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
async def _setup_class(cls):
await super()._setup_class()
cls.listener = EventListener()
cls.client = await async_rs_or_single_client(
cls.client = await cls.unmanaged_async_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@ -407,7 +405,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1)
async def test_not_primary_error(self):
address = next(iter(await async_client_context.client.secondaries))
client = await async_single_client(*address, event_listeners=[self.listener])
client = await self.async_single_client(*address, event_listeners=[self.listener])
# Clear authentication command results from the listener.
await client.admin.command("ping")
self.listener.reset()
@ -1146,7 +1144,7 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
# We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener)
cls.client = await async_single_client()
cls.client = await cls.unmanaged_async_single_client()
# Get one (authenticated) socket in the pool.
await cls.client.pymongo_test.command("ping")

View File

@ -36,9 +36,7 @@ from test.asynchronous import (
from test.utils import (
EventListener,
ExceptionCatchingThread,
async_rs_or_single_client,
async_wait_until,
rs_or_single_client,
wait_until,
)
@ -90,7 +88,7 @@ class TestSession(AsyncIntegrationTest):
await super()._setup_class()
# Create a second client so we can make sure clients cannot share
# sessions.
cls.client2 = await async_rs_or_single_client()
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
@ -105,7 +103,7 @@ class TestSession(AsyncIntegrationTest):
async def asyncSetUp(self):
self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener()
self.client = await async_rs_or_single_client(
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener]
)
self.addAsyncCleanup(self.client.close)
@ -202,7 +200,7 @@ class TestSession(AsyncIntegrationTest):
failures = 0
for _ in range(5):
listener = EventListener()
client = async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
cursor = client.db.test.find({})
ops: List[Tuple[Callable, List[Any]]] = [
(client.db.test.find_one, [{"_id": 1}]),
@ -285,7 +283,7 @@ class TestSession(AsyncIntegrationTest):
async def test_end_sessions(self):
# Use a new client so that the tearDown hook does not error.
listener = SessionTestListener()
client = await async_rs_or_single_client(event_listeners=[listener])
client = await self.async_rs_or_single_client(event_listeners=[listener])
# Start many sessions.
sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)]
for s in sessions:
@ -789,8 +787,7 @@ class TestSession(AsyncIntegrationTest):
async def test_unacknowledged_writes(self):
# Ensure the collection exists.
await self.client.pymongo_test.test_unacked_writes.insert_one({})
client = await async_rs_or_single_client(w=0, event_listeners=[self.listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener])
db = client.pymongo_test
coll = db.test_unacked_writes
ops: list = [
@ -838,7 +835,7 @@ class TestCausalConsistency(AsyncUnitTest):
@classmethod
async def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = await async_rs_or_single_client(event_listeners=[cls.listener])
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):
@ -1153,10 +1150,9 @@ class TestClusterTime(AsyncIntegrationTest):
async def test_cluster_time(self):
listener = SessionTestListener()
# Prevent heartbeats from updating $clusterTime between operations.
client = await async_rs_or_single_client(
client = await self.async_rs_or_single_client(
event_listeners=[listener], heartbeatFrequencyMS=999999
)
self.addAsyncCleanup(client.close)
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
await collection.insert_many([{} for _ in range(10)])

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import sys
from io import BytesIO
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
@ -25,8 +26,6 @@ sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import (
OvertCommandListener,
async_rs_client,
async_single_client,
wait_until,
)
from typing import List
@ -59,7 +58,18 @@ _IS_SYNC = False
UNPIN_TEST_MAX_ATTEMPTS = 50
class TestTransactions(AsyncIntegrationTest):
class AsyncTransactionsBase(AsyncSpecRunner):
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
if (
"secondary" in self.id()
and not async_client_context.is_mongos
and not async_client_context.has_secondaries
):
raise unittest.SkipTest("No secondaries")
class TestTransactions(AsyncTransactionsBase):
RUN_ON_SERVERLESS = True
@async_client_context.require_transactions
@ -92,8 +102,7 @@ class TestTransactions(AsyncIntegrationTest):
@async_client_context.require_transactions
async def test_transaction_write_concern_override(self):
"""Test txn overrides Client/Database/Collection write_concern."""
client = await async_rs_client(w=0)
self.addAsyncCleanup(client.close)
client = await self.async_rs_client(w=0)
db = client.test
coll = db.test
await coll.insert_one({})
@ -150,12 +159,13 @@ class TestTransactions(AsyncIntegrationTest):
async def test_unpin_for_next_transaction(self):
# Increase localThresholdMS and wait until both nodes are discovered
# to avoid false positives.
client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000
)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
self.addAsyncCleanup(client.close)
async with client.start_session() as s:
# Session is pinned to Mongos.
async with await s.start_transaction():
@ -178,12 +188,13 @@ class TestTransactions(AsyncIntegrationTest):
async def test_unpin_for_non_transaction_operation(self):
# Increase localThresholdMS and wait until both nodes are discovered
# to avoid false positives.
client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000
)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
self.addAsyncCleanup(client.close)
async with client.start_session() as s:
# Session is pinned to Mongos.
async with await s.start_transaction():
@ -307,11 +318,10 @@ class TestTransactions(AsyncIntegrationTest):
# Start a transaction with a batch of operations that needs to be
# split.
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
await coll.delete_many({})
listener.reset()
self.addAsyncCleanup(client.close)
self.addAsyncCleanup(coll.drop)
large_str = "\0" * (1 * 1024 * 1024)
ops: List[InsertOne[RawBSONDocument]] = [
@ -336,8 +346,7 @@ class TestTransactions(AsyncIntegrationTest):
@async_client_context.require_transactions
async def test_transaction_direct_connection(self):
client = await async_single_client()
self.addAsyncCleanup(client.close)
client = await self.async_single_client()
coll = client.pymongo_test.test
# Make sure the collection exists.
@ -393,14 +402,16 @@ class PatchSessionTimeout:
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout
class TestTransactionsConvenientAPI(AsyncIntegrationTest):
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.mongos_clients = []
if async_client_context.supports_transactions():
for address in async_client_context.mongoses:
cls.mongos_clients.append(await async_single_client("{}:{}".format(*address)))
cls.mongos_clients.append(
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
@classmethod
async def _tearDown_class(cls):
@ -450,8 +461,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest):
@async_client_context.require_transactions
async def test_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
async def callback(session):
@ -479,8 +489,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest):
@async_client_context.require_transactions
async def test_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
async def callback(session):
@ -514,8 +523,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest):
@async_client_context.require_transactions
async def test_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
async def callback(session):

View File

@ -25,7 +25,6 @@ from test.utils import (
EventListener,
OvertCommandListener,
ServerAndTopologyEventListener,
async_rs_client,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
@ -101,6 +100,8 @@ class AsyncSpecRunner(AsyncIntegrationTest):
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
def setUp(self):
@ -527,7 +528,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
host = async_client_context.MULTI_MONGOS_LB_URI
elif async_client_context.is_mongos:
host = async_client_context.mongos_seeds()
client = await async_rs_client(
client = await self.async_rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import os
import sys
import unittest
from test import PyMongoTestCase
from unittest.mock import patch
import pytest
@ -36,7 +37,7 @@ from pymongo.uri_parser import parse_uri
pytestmark = pytest.mark.auth_aws
class TestAuthAWS(unittest.TestCase):
class TestAuthAWS(PyMongoTestCase):
uri: str
@classmethod
@ -69,7 +70,7 @@ class TestAuthAWS(unittest.TestCase):
self.skipTest("Not testing cached credentials")
# Make a connection to ensure that we enable caching.
client = MongoClient(self.uri)
client = self.simple_client(self.uri)
client.get_database().test.find_one()
client.close()
@ -79,7 +80,7 @@ class TestAuthAWS(unittest.TestCase):
auth.set_cached_credentials(None)
self.assertEqual(auth.get_cached_credentials(), None)
client = MongoClient(self.uri)
client = self.simple_client(self.uri)
client.get_database().test.find_one()
client.close()
return auth.get_cached_credentials()
@ -90,8 +91,7 @@ class TestAuthAWS(unittest.TestCase):
def test_cache_about_to_expire(self):
creds = self.setup_cache()
client = MongoClient(self.uri)
self.addCleanup(client.close)
client = self.simple_client(self.uri)
# Make the creds about to expire.
creds = auth.get_cached_credentials()
@ -107,8 +107,7 @@ class TestAuthAWS(unittest.TestCase):
def test_poisoned_cache(self):
creds = self.setup_cache()
client = MongoClient(self.uri)
self.addCleanup(client.close)
client = self.simple_client(self.uri)
# Poison the creds with invalid password.
assert creds is not None
@ -130,8 +129,7 @@ class TestAuthAWS(unittest.TestCase):
self.assertIsNotNone(creds)
os.environ.copy()
client = MongoClient(self.uri)
self.addCleanup(client.close)
client = self.simple_client(self.uri)
client.get_database().test.find_one()
@ -149,8 +147,7 @@ class TestAuthAWS(unittest.TestCase):
auth.set_cached_credentials(None)
client2 = MongoClient(self.uri)
self.addCleanup(client2.close)
client2 = self.simple_client(self.uri)
with patch.dict("os.environ", mock_env):
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo")
@ -166,8 +163,7 @@ class TestAuthAWS(unittest.TestCase):
if creds.token:
mock_env["AWS_SESSION_TOKEN"] = creds.token
client = MongoClient(self.uri)
self.addCleanup(client.close)
client = self.simple_client(self.uri)
with patch.dict(os.environ, mock_env):
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username)
@ -177,22 +173,19 @@ class TestAuthAWS(unittest.TestCase):
mock_env["AWS_ACCESS_KEY_ID"] = "foo"
client2 = MongoClient(self.uri)
self.addCleanup(client2.close)
client2 = self.simple_client(self.uri)
with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure):
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo")
client2.get_database().test.find_one()
class TestAWSLambdaExamples(unittest.TestCase):
class TestAWSLambdaExamples(PyMongoTestCase):
def test_shared_client(self):
# Start AWS Lambda Example 1
import os
from pymongo import MongoClient
client = MongoClient(host=os.environ["MONGODB_URI"])
client = self.simple_client(host=os.environ["MONGODB_URI"])
def lambda_handler(event, context):
return client.db.command("ping")
@ -203,9 +196,7 @@ class TestAWSLambdaExamples(unittest.TestCase):
# Start AWS Lambda Example 2
import os
from pymongo import MongoClient
client = MongoClient(
client = self.simple_client(
host=os.environ["MONGODB_URI"],
authSource="$external",
authMechanism="MONGODB-AWS",

View File

@ -23,6 +23,7 @@ import unittest
import warnings
from contextlib import contextmanager
from pathlib import Path
from test import PyMongoTestCase
from typing import Dict
import pytest
@ -56,7 +57,7 @@ globals().update(generate_test_classes(str(TEST_PATH), module=__name__))
pytestmark = pytest.mark.auth_oidc
class OIDCTestBase(unittest.TestCase):
class OIDCTestBase(PyMongoTestCase):
@classmethod
def setUpClass(cls):
cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
@ -94,6 +95,7 @@ class OIDCTestBase(unittest.TestCase):
yield
finally:
client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off")
client.close()
@pytest.mark.auth_oidc
@ -149,7 +151,9 @@ class TestAuthOIDCHuman(OIDCTestBase):
if not len(args):
args = [self.uri_single]
return MongoClient(*args, authmechanismproperties=props, **kwargs)
client = self.simple_client(*args, authmechanismproperties=props, **kwargs)
return client
def test_1_1_single_principal_implicit_username(self):
# Create default OIDC client with authMechanism=MONGODB-OIDC.

View File

@ -29,13 +29,12 @@ except ImportError:
from bson.objectid import ObjectId
from pymongo import MongoClient
from pymongo.errors import OperationFailure
pytestmark = pytest.mark.mockupdb
class TestCursor(unittest.TestCase):
class TestCursor(PyMongoTestCase):
def test_getmore_load_balanced(self):
server = MockupDB()
server.autoresponds(
@ -50,7 +49,7 @@ class TestCursor(unittest.TestCase):
server.run()
self.addCleanup(server.stop)
client = MongoClient(server.uri, loadBalanced=True)
client = self.simple_client(server.uri, loadBalanced=True)
self.addCleanup(client.close)
collection = client.db.coll
cursor = collection.find()
@ -77,7 +76,7 @@ class TestRetryableErrorCodeCatch(PyMongoTestCase):
self.addCleanup(server.stop)
server.autoresponds("ismaster", maxWireVersion=6)
client = MongoClient(server.uri)
client = self.simple_client(server.uri)
with going(lambda: server.receives(OpMsg({"find": "collection"})).command_err(code=code)):
cursor = client.db.collection.find()

View File

@ -48,8 +48,11 @@ else:
def _connect(options):
uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS={TIMEOUT_MS}&tlsCAFile={CA_FILE}&{options}"
print(uri)
client = pymongo.MongoClient(uri)
client.admin.command("ping")
try:
client = pymongo.MongoClient(uri)
client.admin.command("ping")
finally:
client.close()
class TestOCSP(unittest.TestCase):

View File

@ -23,16 +23,14 @@ from urllib.parse import quote_plus
sys.path[0:0] = [""]
from test import IntegrationTest, SkipTest, client_context, unittest
from test.utils import (
AllowListEventListener,
delay,
ignore_deprecations,
rs_or_single_client,
rs_or_single_client_noauth,
single_client,
single_client_noauth,
from test import (
IntegrationTest,
PyMongoTestCase,
SkipTest,
client_context,
unittest,
)
from test.utils import AllowListEventListener, delay, ignore_deprecations
from pymongo import MongoClient, monitoring
from pymongo.auth_shared import _build_credentials_tuple
@ -81,7 +79,7 @@ class AutoAuthenticateThread(threading.Thread):
self.success = True
class TestGSSAPI(unittest.TestCase):
class TestGSSAPI(PyMongoTestCase):
mech_properties: str
service_realm_required: bool
@ -138,7 +136,7 @@ class TestGSSAPI(unittest.TestCase):
if not self.service_realm_required:
# Without authMechanismProperties.
client = MongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -149,11 +147,11 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].collection.find_one()
# Log in using URI, without authMechanismProperties.
client = MongoClient(uri)
client = self.simple_client(uri)
client[GSSAPI_DB].collection.find_one()
# Authenticate with authMechanismProperties.
client = MongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -166,14 +164,14 @@ class TestGSSAPI(unittest.TestCase):
# Log in using URI, with authMechanismProperties.
mech_uri = uri + f"&authMechanismProperties={self.mech_properties}"
client = MongoClient(mech_uri)
client = self.simple_client(mech_uri)
client[GSSAPI_DB].collection.find_one()
set_name = client_context.replica_set_name
if set_name:
if not self.service_realm_required:
# Without authMechanismProperties
client = MongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -185,11 +183,11 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].list_collection_names()
uri = uri + f"&replicaSet={set_name!s}"
client = MongoClient(uri)
client = self.simple_client(uri)
client[GSSAPI_DB].list_collection_names()
# With authMechanismProperties
client = MongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -202,13 +200,13 @@ class TestGSSAPI(unittest.TestCase):
client[GSSAPI_DB].list_collection_names()
mech_uri = mech_uri + f"&replicaSet={set_name!s}"
client = MongoClient(mech_uri)
client = self.simple_client(mech_uri)
client[GSSAPI_DB].list_collection_names()
@ignore_deprecations
@client_context.require_sync
def test_gssapi_threaded(self):
client = MongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -244,7 +242,7 @@ class TestGSSAPI(unittest.TestCase):
set_name = client_context.replica_set_name
if set_name:
client = MongoClient(
client = self.simple_client(
GSSAPI_HOST,
GSSAPI_PORT,
username=GSSAPI_PRINCIPAL,
@ -267,14 +265,14 @@ class TestGSSAPI(unittest.TestCase):
self.assertTrue(thread.success)
class TestSASLPlain(unittest.TestCase):
class TestSASLPlain(PyMongoTestCase):
@classmethod
def setUpClass(cls):
if not SASL_HOST or not SASL_USER or not SASL_PASS:
raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL")
def test_sasl_plain(self):
client = MongoClient(
client = self.simple_client(
SASL_HOST,
SASL_PORT,
username=SASL_USER,
@ -293,12 +291,12 @@ class TestSASLPlain(unittest.TestCase):
SASL_PORT,
SASL_DB,
)
client = MongoClient(uri)
client = self.simple_client(uri)
client.ldap.test.find_one()
set_name = client_context.replica_set_name
if set_name:
client = MongoClient(
client = self.simple_client(
SASL_HOST,
SASL_PORT,
replicaSet=set_name,
@ -317,7 +315,7 @@ class TestSASLPlain(unittest.TestCase):
SASL_DB,
str(set_name),
)
client = MongoClient(uri)
client = self.simple_client(uri)
client.ldap.test.find_one()
def test_sasl_plain_bad_credentials(self):
@ -331,8 +329,8 @@ class TestSASLPlain(unittest.TestCase):
)
return uri
bad_user = MongoClient(auth_string("not-user", SASL_PASS))
bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd"))
bad_user = self.simple_client(auth_string("not-user", SASL_PASS))
bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd"))
# OperationFailure raised upon connecting.
with self.assertRaises(OperationFailure):
bad_user.admin.command("ping")
@ -354,7 +352,7 @@ class TestSCRAMSHA1(IntegrationTest):
def test_scram_sha1(self):
host, port = client_context.host, client_context.port
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
"mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port)
)
client.pymongo_test.command("dbstats")
@ -365,7 +363,7 @@ class TestSCRAMSHA1(IntegrationTest):
"@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1"
"&replicaSet=%s" % (host, port, client_context.replica_set_name)
)
client = single_client_noauth(uri)
client = self.single_client_noauth(uri)
client.pymongo_test.command("dbstats")
db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY)
db.command("dbstats")
@ -393,7 +391,7 @@ class TestSCRAM(IntegrationTest):
"testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
)
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram", event_listeners=[listener]
)
client.testscram.command("dbstats")
@ -430,36 +428,38 @@ class TestSCRAM(IntegrationTest):
)
# Step 2: verify auth success cases
client = rs_or_single_client_noauth(username="sha1", password="pwd", authSource="testscram")
client = self.rs_or_single_client_noauth(
username="sha1", password="pwd", authSource="testscram"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
client.testscram.command("dbstats")
# Step 2: SCRAM-SHA-1 and SCRAM-SHA-256
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
client.testscram.command("dbstats")
self.listener.reset()
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="both", password="pwd", authSource="testscram", event_listeners=[self.listener]
)
client.testscram.command("dbstats")
@ -472,19 +472,19 @@ class TestSCRAM(IntegrationTest):
self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256")
# Step 3: verify auth failure conditions
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
with self.assertRaises(OperationFailure):
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
)
with self.assertRaises(OperationFailure):
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="not-a-user", password="pwd", authSource="testscram"
)
with self.assertRaises(OperationFailure):
@ -497,7 +497,7 @@ class TestSCRAM(IntegrationTest):
port,
client_context.replica_set_name,
)
client = single_client_noauth(uri)
client = self.single_client_noauth(uri)
client.testscram.command("dbstats")
db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY)
db.command("dbstats")
@ -517,12 +517,12 @@ class TestSCRAM(IntegrationTest):
"testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
)
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="\u2168", password="\u2163", authSource="testscram"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="\u2168",
password="\u2163",
authSource="testscram",
@ -530,17 +530,17 @@ class TestSCRAM(IntegrationTest):
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="\u2168", password="IV", authSource="testscram"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="IX", password="I\u00ADX", authSource="testscram"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="IX",
password="I\u00ADX",
authSource="testscram",
@ -548,25 +548,29 @@ class TestSCRAM(IntegrationTest):
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256"
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth(
client = self.rs_or_single_client_noauth(
"mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port)
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth("mongodb://\u2168:IV@%s:%d/testscram" % (host, port))
client = self.rs_or_single_client_noauth(
"mongodb://\u2168:IV@%s:%d/testscram" % (host, port)
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth("mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port))
client = self.rs_or_single_client_noauth(
"mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)
)
client.testscram.command("dbstats")
client = rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port))
client = self.rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port))
client.testscram.command("dbstats")
def test_cache(self):
client = single_client()
client = self.single_client()
credentials = client.options.pool_options._credentials
cache = credentials.cache
self.assertIsNotNone(cache)
@ -591,8 +595,7 @@ class TestSCRAM(IntegrationTest):
coll.insert_one({"_id": 1})
# The first thread to call find() will authenticate
client = rs_or_single_client()
self.addCleanup(client.close)
client = self.rs_or_single_client()
coll = client.db.test
threads = []
for _ in range(4):
@ -619,7 +622,7 @@ class TestAuthURIOptions(IntegrationTest):
def test_uri_options(self):
# Test default to admin
host, port = client_context.host, client_context.port
client = rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))
client = self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))
self.assertTrue(client.admin.command("dbstats"))
if client_context.is_rs:
@ -628,14 +631,14 @@ class TestAuthURIOptions(IntegrationTest):
port,
client_context.replica_set_name,
)
client = single_client_noauth(uri)
client = self.single_client_noauth(uri)
self.assertTrue(client.admin.command("dbstats"))
db = client.get_database("admin", read_preference=ReadPreference.SECONDARY)
self.assertTrue(db.command("dbstats"))
# Test explicit database
uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port)
client = rs_or_single_client_noauth(uri)
client = self.rs_or_single_client_noauth(uri)
with self.assertRaises(OperationFailure):
client.admin.command("dbstats")
self.assertTrue(client.pymongo_test.command("dbstats"))
@ -646,7 +649,7 @@ class TestAuthURIOptions(IntegrationTest):
port,
client_context.replica_set_name,
)
client = single_client_noauth(uri)
client = self.single_client_noauth(uri)
with self.assertRaises(OperationFailure):
client.admin.command("dbstats")
self.assertTrue(client.pymongo_test.command("dbstats"))
@ -655,7 +658,7 @@ class TestAuthURIOptions(IntegrationTest):
# Test authSource
uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port)
client = rs_or_single_client_noauth(uri)
client = self.rs_or_single_client_noauth(uri)
with self.assertRaises(OperationFailure):
client.pymongo_test2.command("dbstats")
self.assertTrue(client.pymongo_test.command("dbstats"))
@ -665,7 +668,7 @@ class TestAuthURIOptions(IntegrationTest):
"mongodb://user:pass@%s:%d/pymongo_test2?replicaSet="
"%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name)
)
client = single_client_noauth(uri)
client = self.single_client_noauth(uri)
with self.assertRaises(OperationFailure):
client.pymongo_test2.command("dbstats")
self.assertTrue(client.pymongo_test.command("dbstats"))

View File

@ -20,6 +20,7 @@ import json
import os
import sys
import warnings
from test import PyMongoTestCase
sys.path[0:0] = [""]
@ -34,7 +35,7 @@ _IS_SYNC = True
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")
class TestAuthSpec(unittest.TestCase):
class TestAuthSpec(PyMongoTestCase):
pass
@ -54,7 +55,7 @@ def create_test(test_case):
warnings.simplefilter("default")
self.assertRaises(Exception, MongoClient, uri, connect=False)
else:
client = MongoClient(uri, connect=False)
client = self.simple_client(uri, connect=False)
credentials = client.options.pool_options._credentials
if credential is None:
self.assertIsNone(credentials)

View File

@ -24,22 +24,13 @@ from pymongo.synchronous.mongo_client import MongoClient
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, remove_all_users, unittest
from test.utils import (
rs_or_single_client_noauth,
single_client,
wait_until,
)
from test.utils import wait_until
from bson.binary import Binary, UuidRepresentation
from bson.codec_options import CodecOptions
from bson.objectid import ObjectId
from pymongo.common import partition_node
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure,
)
from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import *
from pymongo.synchronous.collection import Collection
from pymongo.write_concern import WriteConcern
@ -913,7 +904,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
def test_readonly(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
cli = rs_or_single_client_noauth(
cli = self.rs_or_single_client_noauth(
username="readonly", password="pw", authSource="pymongo_test"
)
coll = cli.pymongo_test.test
@ -924,7 +915,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
def test_no_remove(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
cli = rs_or_single_client_noauth(
cli = self.rs_or_single_client_noauth(
username="noremove", password="pw", authSource="pymongo_test"
)
coll = cli.pymongo_test.test
@ -952,7 +943,7 @@ class TestBulkWriteConcern(BulkTestBase):
if cls.w is not None and cls.w > 1:
for member in (client_context.hello)["hosts"]:
if member != (client_context.hello)["primary"]:
cls.secondary = single_client(*partition_node(member))
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
break
@classmethod

View File

@ -28,12 +28,17 @@ from typing import no_type_check
sys.path[0:0] = [""]
from test import IntegrationTest, Version, client_context, unittest
from test import (
IntegrationTest,
PyMongoTestCase,
Version,
client_context,
unittest,
)
from test.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
EventListener,
rs_or_single_client,
wait_until,
)
@ -69,8 +74,7 @@ class TestChangeStreamBase(IntegrationTest):
def client_with_listener(self, *commands):
"""Return a client with a AllowListEventListener."""
listener = AllowListEventListener(*commands)
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
return client, listener
def watched_collection(self, *args, **kwargs):
@ -174,7 +178,7 @@ class APITestsMixin:
@no_type_check
def test_try_next_runs_one_getmore(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
# Connect to the cluster.
client.admin.command("ping")
listener.reset()
@ -232,7 +236,7 @@ class APITestsMixin:
@no_type_check
def test_batch_size_is_honored(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
# Connect to the cluster.
client.admin.command("ping")
listener.reset()
@ -473,7 +477,7 @@ class ProseSpecTestsMixin:
@no_type_check
def _client_with_listener(self, *commands):
listener = AllowListEventListener(*commands)
client = rs_or_single_client(event_listeners=[listener])
client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
return client, listener
@ -1111,7 +1115,7 @@ class TestAllLegacyScenarios(IntegrationTest):
def _setup_class(cls):
super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = rs_or_single_client(event_listeners=[cls.listener])
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,6 @@ from test import (
)
from test.utils import (
OvertCommandListener,
rs_or_single_client,
)
from unittest.mock import patch
@ -38,7 +37,6 @@ from pymongo.errors import (
InvalidOperation,
NetworkTimeout,
)
from pymongo.monitoring import *
from pymongo.operations import *
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.write_concern import WriteConcern
@ -97,8 +95,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_batch_splits_if_num_operations_too_large(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
models = []
for _ in range(self.max_write_batch_size + 1):
@ -123,8 +120,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_batch_splits_if_ops_payload_too_large(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
models = []
num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1)
@ -157,11 +153,10 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_failCommand_fail_point
def test_collects_write_concern_errors_across_batches(self):
listener = OvertCommandListener()
client = rs_or_single_client(
client = self.rs_or_single_client(
event_listeners=[listener],
retryWrites=False,
)
self.addCleanup(client.close)
fail_command = {
"configureFailPoint": "failCommand",
@ -200,8 +195,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_collects_write_errors_across_batches_unordered(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addCleanup(collection.drop)
@ -231,8 +225,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_collects_write_errors_across_batches_ordered(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addCleanup(collection.drop)
@ -262,8 +255,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_handles_cursor_requiring_getMore(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addCleanup(collection.drop)
@ -304,8 +296,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_standalone
def test_handles_cursor_requiring_getMore_within_transaction(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addCleanup(collection.drop)
@ -348,8 +339,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_failCommand_fail_point
def test_handles_getMore_error(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
collection = client.db["coll"]
self.addCleanup(collection.drop)
@ -403,8 +393,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_returns_error_if_unacknowledged_too_large_insert(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
b_repeated = "b" * self.max_bson_object_size
@ -460,8 +449,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_no_batch_splits_if_new_namespace_is_not_too_large(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
num_models, models = self._setup_namespace_test_models()
models.append(
@ -492,8 +480,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_no_serverless
def test_batch_splits_if_new_namespace_is_too_large(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
num_models, models = self._setup_namespace_test_models()
c_repeated = "c" * 200
@ -530,8 +517,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
@client_context.require_version_min(8, 0, 0, -24)
@client_context.require_no_serverless
def test_returns_error_if_no_writes_can_be_added_to_ops(self):
client = rs_or_single_client()
self.addCleanup(client.close)
client = self.rs_or_single_client()
# Document too large.
b_repeated = "b" * self.max_message_size_bytes
@ -554,8 +540,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
key_vault_namespace="db.coll",
kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}},
)
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(InvalidOperation) as context:
@ -580,7 +565,7 @@ class TestClientBulkWriteCSOT(IntegrationTest):
def test_timeout_in_multi_batch_bulk_write(self):
_OVERHEAD = 500
internal_client = rs_or_single_client(timeoutMS=None)
internal_client = self.rs_or_single_client(timeoutMS=None)
self.addCleanup(internal_client.close)
collection = internal_client.db["coll"]
@ -605,14 +590,13 @@ class TestClientBulkWriteCSOT(IntegrationTest):
)
listener = OvertCommandListener()
client = rs_or_single_client(
client = self.rs_or_single_client(
event_listeners=[listener],
readConcernLevel="majority",
readPreference="primary",
timeoutMS=2000,
w="majority",
)
self.addCleanup(client.close)
client.admin.command("ping") # Init the client first.
with self.assertRaises(ClientBulkWriteException) as context:
client.bulk_write(models=models)

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import functools
import warnings
from test import IntegrationTest, client_context, unittest
from test.utils import EventListener, rs_or_single_client
from test.utils import EventListener
from typing import Any
from pymongo.collation import (
@ -99,7 +99,7 @@ class TestCollation(IntegrationTest):
def setUpClass(cls):
super().setUpClass()
cls.listener = EventListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener])
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
cls.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings()

View File

@ -29,6 +29,7 @@ sys.path[0:0] = [""]
from test import ( # TODO: fix sync imports in PYTHON-4528
IntegrationTest,
UnitTest,
client_context,
unittest,
)
@ -37,8 +38,6 @@ from test.utils import (
EventListener,
get_pool,
is_mongos,
rs_or_single_client,
single_client,
wait_until,
)
@ -81,14 +80,20 @@ from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class TestCollectionNoConnect(unittest.TestCase):
class TestCollectionNoConnect(UnitTest):
"""Test Collection features on a client that does not connect."""
db: Database
client: MongoClient
@classmethod
def setUpClass(cls):
cls.db = MongoClient(connect=False).pymongo_test
def _setup_class(cls):
cls.client = MongoClient(connect=False)
cls.db = cls.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.client.close()
def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5)
@ -1800,8 +1805,7 @@ class TestCollection(IntegrationTest):
# Insert enough documents to require more than one batch
self.db.test.insert_many([{"i": i} for i in range(150)])
client = rs_or_single_client(maxPoolSize=1)
self.addCleanup(client.close)
client = self.rs_or_single_client(maxPoolSize=1)
pool = get_pool(client)
# Make sure the socket is returned after exhaustion.
@ -2077,7 +2081,7 @@ class TestCollection(IntegrationTest):
def test_find_one_and_write_concern(self):
listener = EventListener()
db = (single_client(event_listeners=[listener]))[self.db.name]
db = (self.single_client(event_listeners=[listener]))[self.db.name]
# non-default WriteConcern.
c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0))
# default WriteConcern.

View File

@ -22,7 +22,7 @@ import sys
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import EventListener, rs_or_single_client
from test.utils import EventListener
from bson.dbref import DBRef
from pymongo.operations import IndexModel
@ -109,7 +109,7 @@ class TestComment(IntegrationTest):
@client_context.require_replica_set
def test_database_helpers(self):
listener = EventListener()
db = rs_or_single_client(event_listeners=[listener]).db
db = self.rs_or_single_client(event_listeners=[listener]).db
helpers = [
(db.watch, []),
(db.command, ["hello"]),
@ -126,7 +126,7 @@ class TestComment(IntegrationTest):
@client_context.require_replica_set
def test_client_helpers(self):
listener = EventListener()
cli = rs_or_single_client(event_listeners=[listener])
cli = self.rs_or_single_client(event_listeners=[listener])
helpers = [
(cli.watch, []),
(cli.list_databases, []),
@ -141,7 +141,7 @@ class TestComment(IntegrationTest):
@client_context.require_version_min(4, 7, -1)
def test_collection_helpers(self):
listener = EventListener()
db = rs_or_single_client(event_listeners=[listener])[self.db.name]
db = self.rs_or_single_client(event_listeners=[listener])[self.db.name]
coll = db.get_collection("test")
helpers = [

View File

@ -21,7 +21,6 @@ import uuid
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, connected, unittest
from test.utils import rs_or_single_client, single_client
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
from bson.codec_options import CodecOptions
@ -111,10 +110,10 @@ class TestCommon(IntegrationTest):
)
def test_write_concern(self):
c = rs_or_single_client(connect=False)
c = self.rs_or_single_client(connect=False)
self.assertEqual(WriteConcern(), c.write_concern)
c = rs_or_single_client(connect=False, w=2, wTimeoutMS=1000)
c = self.rs_or_single_client(connect=False, w=2, wTimeoutMS=1000)
wc = WriteConcern(w=2, wtimeout=1000)
self.assertEqual(wc, c.write_concern)
@ -134,7 +133,7 @@ class TestCommon(IntegrationTest):
def test_mongo_client(self):
pair = client_context.pair
m = rs_or_single_client(w=0)
m = self.rs_or_single_client(w=0)
coll = m.pymongo_test.write_concern_test
coll.drop()
doc = {"_id": ObjectId()}
@ -143,17 +142,19 @@ class TestCommon(IntegrationTest):
coll = coll.with_options(write_concern=WriteConcern(w=1))
self.assertRaises(OperationFailure, coll.insert_one, doc)
m = rs_or_single_client()
m = self.rs_or_single_client()
coll = m.pymongo_test.write_concern_test
new_coll = coll.with_options(write_concern=WriteConcern(w=0))
self.assertTrue(new_coll.insert_one(doc))
self.assertRaises(OperationFailure, coll.insert_one, doc)
m = rs_or_single_client(f"mongodb://{pair}/", replicaSet=client_context.replica_set_name)
m = self.rs_or_single_client(
f"mongodb://{pair}/", replicaSet=client_context.replica_set_name
)
coll = m.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert_one, doc)
m = rs_or_single_client(
m = self.rs_or_single_client(
f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name
)
@ -161,8 +162,8 @@ class TestCommon(IntegrationTest):
coll.insert_one(doc)
# Equality tests
direct = connected(single_client(w=0))
direct2 = connected(single_client(f"mongodb://{pair}/?w=0", **self.credentials))
direct = connected(self.single_client(w=0))
direct2 = connected(self.single_client(f"mongodb://{pair}/?w=0", **self.credentials))
self.assertEqual(direct, direct2)
self.assertFalse(direct != direct2)

View File

@ -30,9 +30,6 @@ from test.utils import (
client_context,
get_pool,
get_pools,
rs_or_single_client,
single_client,
single_client_noauth,
wait_until,
)
from test.utils_spec_runner import SpecRunnerThread
@ -250,7 +247,7 @@ class TestCMAP(IntegrationTest):
else:
kill_cursor_frequency = interval / 1000.0
with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05):
client = single_client(**opts)
client = self.single_client(**opts)
# Update the SD to a known type because the DummyMonitor will not.
# Note we cannot simply call topology.on_change because that would
# internally call pool.ready() which introduces unexpected
@ -323,13 +320,13 @@ class TestCMAP(IntegrationTest):
# Prose tests. Numbers correspond to the prose test number in the spec.
#
def test_1_client_connection_pool_options(self):
client = rs_or_single_client(**self.POOL_OPTIONS)
client = self.rs_or_single_client(**self.POOL_OPTIONS)
self.addCleanup(client.close)
pool_opts = get_pool(client).opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
def test_2_all_client_pools_have_same_options(self):
client = rs_or_single_client(**self.POOL_OPTIONS)
client = self.rs_or_single_client(**self.POOL_OPTIONS)
self.addCleanup(client.close)
client.admin.command("ping")
# Discover at least one secondary.
@ -345,14 +342,14 @@ class TestCMAP(IntegrationTest):
def test_3_uri_connection_pool_options(self):
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
uri = f"mongodb://{client_context.pair}/?{opts}"
client = rs_or_single_client(uri)
client = self.rs_or_single_client(uri)
self.addCleanup(client.close)
pool_opts = get_pool(client).opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
def test_4_subscribe_to_events(self):
listener = CMAPListener()
client = single_client(event_listeners=[listener])
client = self.single_client(event_listeners=[listener])
self.addCleanup(client.close)
self.assertEqual(listener.event_count(PoolCreatedEvent), 1)
@ -376,7 +373,7 @@ class TestCMAP(IntegrationTest):
def test_5_check_out_fails_connection_error(self):
listener = CMAPListener()
client = single_client(event_listeners=[listener])
client = self.single_client(event_listeners=[listener])
self.addCleanup(client.close)
pool = get_pool(client)
@ -403,7 +400,7 @@ class TestCMAP(IntegrationTest):
@client_context.require_no_fips
def test_5_check_out_fails_auth_error(self):
listener = CMAPListener()
client = single_client_noauth(
client = self.single_client_noauth(
username="notauser", password="fail", event_listeners=[listener]
)
self.addCleanup(client.close)
@ -449,7 +446,7 @@ class TestCMAP(IntegrationTest):
def test_close_leaves_pool_unpaused(self):
listener = CMAPListener()
client = single_client(event_listeners=[listener])
client = self.single_client(event_listeners=[listener])
client.admin.command("ping")
pool = get_pool(client)
client.close()

View File

@ -24,7 +24,6 @@ from test.utils import (
CMAPListener,
ensure_all_connected,
repl_set_step_down,
rs_or_single_client,
)
from bson import SON
@ -43,7 +42,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
def setUpClass(cls):
super().setUpClass()
cls.listener = CMAPListener()
cls.client = rs_or_single_client(
cls.client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
)

View File

@ -34,8 +34,8 @@ from test.utils import (
AllowListEventListener,
EventListener,
OvertCommandListener,
delay,
ignore_deprecations,
rs_or_single_client,
wait_until,
)
@ -43,7 +43,7 @@ from bson import decode_all
from bson.code import Code
from pymongo import ASCENDING, DESCENDING
from pymongo.collation import Collation
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure, PyMongoError
from pymongo.operations import _IndexList
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
@ -230,7 +230,7 @@ class TestCursor(IntegrationTest):
self.assertEqual(90, cursor._max_await_time_ms)
listener = AllowListEventListener("find", "getMore")
coll = (rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test
coll = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test
# Tailable_defaults.
coll.find(cursor_type=CursorType.TAILABLE_AWAIT).to_list()
@ -345,8 +345,7 @@ class TestCursor(IntegrationTest):
def test_explain_with_read_concern(self):
# Do not add readConcern level to explain.
listener = AllowListEventListener("explain")
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local"))
self.assertTrue(coll.find().explain())
started = listener.started_events
@ -1252,8 +1251,7 @@ class TestCursor(IntegrationTest):
self.client._process_periodic_tasks()
listener = AllowListEventListener("killCursors")
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
coll = client[self.db.name].test_close_kills_cursors
# Add some test data.
@ -1291,8 +1289,7 @@ class TestCursor(IntegrationTest):
@client_context.require_failCommand_appName
def test_timeout_kills_cursor_synchronously(self):
listener = AllowListEventListener("killCursors")
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
coll = client[self.db.name].test_timeout_kills_cursor
# Add some test data.
@ -1349,8 +1346,7 @@ class TestCursor(IntegrationTest):
def test_getMore_does_not_send_readPreference(self):
listener = AllowListEventListener("find", "getMore")
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
# We never send primary read preference so override the default.
coll = client[self.db.name].get_collection(
"test", read_preference=ReadPreference.PRIMARY_PREFERRED
@ -1406,6 +1402,18 @@ class TestCursor(IntegrationTest):
docs = c.to_list(3)
self.assertEqual(len(docs), 2)
def test_to_list_csot_applied(self):
client = self.single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
client.admin.command("ping")
coll = client.pymongo.test
coll.insert_many([{} for _ in range(5)])
cursor = coll.find({"$where": delay(1)})
with self.assertRaises(PyMongoError) as ctx:
cursor.to_list()
self.assertTrue(ctx.exception.timeout)
@client_context.require_change_streams
def test_command_cursor_to_list(self):
# Set maxAwaitTimeMS=1 to speed up the test.
@ -1435,6 +1443,25 @@ class TestCursor(IntegrationTest):
result = db.test.aggregate([pipeline])
self.assertEqual(len(result.to_list(1)), 1)
@client_context.require_failCommand_blockConnection
def test_command_cursor_to_list_csot_applied(self):
client = self.single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
client.admin.command("ping")
coll = client.pymongo.test
coll.insert_many([{} for _ in range(5)])
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 5},
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 1000},
}
cursor = coll.aggregate([], batchSize=1)
with self.fail_point(fail_command):
with self.assertRaises(PyMongoError) as ctx:
cursor.to_list()
self.assertTrue(ctx.exception.timeout)
class TestRawBatchCursor(IntegrationTest):
def test_find_raw(self):
@ -1454,7 +1481,7 @@ class TestRawBatchCursor(IntegrationTest):
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
with client.start_session() as session:
with session.start_transaction():
batches = (
@ -1484,7 +1511,7 @@ class TestRawBatchCursor(IntegrationTest):
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
):
@ -1505,7 +1532,7 @@ class TestRawBatchCursor(IntegrationTest):
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
db = client[self.db.name]
with client.start_session(snapshot=True) as session:
db.test.distinct("x", {}, session=session)
@ -1566,7 +1593,7 @@ class TestRawBatchCursor(IntegrationTest):
def test_monitoring(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
c = client.pymongo_test.test
c.drop()
c.insert_many([{"_id": i} for i in range(10)])
@ -1632,7 +1659,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
with client.start_session() as session:
with session.start_transaction():
batches = (
@ -1663,7 +1690,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}}
):
@ -1687,7 +1714,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
db = client[self.db.name]
with client.start_session(snapshot=True) as session:
db.test.distinct("x", {}, session=session)
@ -1733,7 +1760,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
def test_monitoring(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
c = client.pymongo_test.test
c.drop()
c.insert_many([{"_id": i} for i in range(10)])
@ -1777,8 +1804,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
@client_context.require_no_mongos
def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
c = client.pymongo_test.test
c.delete_many({})
c.insert_many([{"_id": i} for i in range(3)])

View File

@ -27,7 +27,6 @@ sys.path[0:0] = [""]
from test import client_context, unittest
from test.test_client import IntegrationTest
from test.utils import rs_client
from bson import (
_BUILT_IN_TYPES,
@ -971,7 +970,7 @@ class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustom
if codec_options:
kwargs["type_registry"] = codec_options.type_registry
kwargs["document_class"] = codec_options.document_class
self.watched_target = rs_client(*args, **kwargs)
self.watched_target = self.rs_client(*args, **kwargs)
self.addCleanup(self.watched_target.close)
self.input_target = self.watched_target[self.db.name].test
# Insert a record to ensure db, coll are created.

View File

@ -27,8 +27,6 @@ from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes
from test.utils import (
OvertCommandListener,
rs_client_noauth,
rs_or_single_client,
)
pytestmark = pytest.mark.data_lake
@ -65,7 +63,7 @@ class TestDataLakeProse(IntegrationTest):
# Test killCursors
def test_1(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2)
next(cursor)
@ -90,13 +88,13 @@ class TestDataLakeProse(IntegrationTest):
# Test no auth
def test_2(self):
client = rs_client_noauth()
client = self.rs_client_noauth()
client.admin.command("ping")
# Test with auth
def test_3(self):
for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]:
client = rs_or_single_client(authMechanism=mechanism)
client = self.rs_or_single_client(authMechanism=mechanism)
client[self.TEST_DB][self.TEST_COLLECTION].find_one()

View File

@ -28,7 +28,6 @@ from test.test_custom_types import DECIMAL_CODECOPTS
from test.utils import (
IMPOSSIBLE_WRITE_CONCERN,
OvertCommandListener,
rs_or_single_client,
wait_until,
)
@ -207,7 +206,7 @@ class TestDatabase(IntegrationTest):
def test_list_collection_names_filter(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
db = client[self.db.name]
db.capped.drop()
db.create_collection("capped", capped=True, size=4096)
@ -234,8 +233,7 @@ class TestDatabase(IntegrationTest):
def test_check_exists(self):
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener])
db = client[self.db.name]
db.drop_collection("unique")
db.create_collection("unique", check_exists=True)
@ -323,7 +321,7 @@ class TestDatabase(IntegrationTest):
self.client.drop_database("pymongo_test")
def test_list_collection_names_single_socket(self):
client = rs_or_single_client(maxPoolSize=1)
client = self.rs_or_single_client(maxPoolSize=1)
client.drop_database("test_collection_names_single_socket")
db = client.test_collection_names_single_socket
for i in range(200):

View File

@ -22,7 +22,7 @@ import threading
sys.path[0:0] = [""]
from test import IntegrationTest, unittest
from test import IntegrationTest, PyMongoTestCase, unittest
from test.pymongo_mocks import DummyMonitor
from test.unified_format import generate_test_classes
from test.utils import (
@ -32,9 +32,7 @@ from test.utils import (
assertion_context,
client_context,
get_pool,
rs_or_single_client,
server_name_to_type,
single_client,
wait_until,
)
from unittest.mock import patch
@ -272,7 +270,7 @@ class TestIgnoreStaleErrors(IntegrationTest):
def test_ignore_stale_connection_errors(self):
N_THREADS = 5
barrier = threading.Barrier(N_THREADS, timeout=30)
client = rs_or_single_client(minPoolSize=N_THREADS)
client = self.rs_or_single_client(minPoolSize=N_THREADS)
self.addCleanup(client.close)
# Wait for initial discovery.
@ -319,7 +317,7 @@ class TestPoolManagement(IntegrationTest):
def test_pool_unpause(self):
# This test implements the prose test "Connection Pool Management"
listener = CMAPHeartbeatListener()
client = single_client(
client = self.single_client(
appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener]
)
self.addCleanup(client.close)
@ -353,7 +351,7 @@ class TestServerMonitoringMode(IntegrationTest):
super().setUp()
def test_rtt_connection_is_enabled_stream(self):
client = rs_or_single_client(serverMonitoringMode="stream")
client = self.rs_or_single_client(serverMonitoringMode="stream")
self.addCleanup(client.close)
client.admin.command("ping")
@ -373,7 +371,7 @@ class TestServerMonitoringMode(IntegrationTest):
wait_until(predicate, "find all RTT monitors")
def test_rtt_connection_is_disabled_poll(self):
client = rs_or_single_client(serverMonitoringMode="poll")
client = self.rs_or_single_client(serverMonitoringMode="poll")
self.addCleanup(client.close)
self.assert_rtt_connection_is_disabled(client)
@ -387,7 +385,7 @@ class TestServerMonitoringMode(IntegrationTest):
]
for env in envs:
with patch.dict("os.environ", env):
client = rs_or_single_client(serverMonitoringMode="auto")
client = self.rs_or_single_client(serverMonitoringMode="auto")
self.addCleanup(client.close)
self.assert_rtt_connection_is_disabled(client)
@ -415,7 +413,7 @@ class TCPServer(socketserver.TCPServer):
self.server_close()
class TestHeartbeatStartOrdering(unittest.TestCase):
class TestHeartbeatStartOrdering(PyMongoTestCase):
def test_heartbeat_start_ordering(self):
events = []
listener = HeartbeatEventsListListener(events)
@ -423,7 +421,7 @@ class TestHeartbeatStartOrdering(unittest.TestCase):
server.events = events
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
server_thread.start()
_c = MongoClient(
_c = self.simple_client(
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
)
server_thread.join()

View File

@ -22,16 +22,15 @@ import sys
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test import IntegrationTest, PyMongoTestCase, client_context, unittest
from test.utils import wait_until
from pymongo.common import validate_read_preference_tags
from pymongo.errors import ConfigurationError
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.uri_parser import parse_uri, split_hosts
class TestDNSRepl(unittest.TestCase):
class TestDNSRepl(PyMongoTestCase):
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set"
)
@ -42,7 +41,7 @@ class TestDNSRepl(unittest.TestCase):
pass
class TestDNSLoadBalanced(unittest.TestCase):
class TestDNSLoadBalanced(PyMongoTestCase):
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced"
)
@ -53,7 +52,7 @@ class TestDNSLoadBalanced(unittest.TestCase):
pass
class TestDNSSharded(unittest.TestCase):
class TestDNSSharded(PyMongoTestCase):
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded")
load_balanced = False
@ -120,7 +119,7 @@ def create_test(test_case):
# tests.
copts["tlsAllowInvalidHostnames"] = True
client = MongoClient(uri, **copts)
client = PyMongoTestCase.unmanaged_simple_client(uri, **copts)
if num_seeds is not None:
self.assertEqual(len(client._topology_settings.seeds), num_seeds)
if hosts is not None:
@ -133,6 +132,7 @@ def create_test(test_case):
client.admin.command("ping")
# XXX: we should block until SRV poller runs at least once
# and re-run these assertions.
client.close()
else:
try:
parse_uri(uri)
@ -157,37 +157,37 @@ create_tests(TestDNSLoadBalanced)
create_tests(TestDNSSharded)
class TestParsingErrors(unittest.TestCase):
class TestParsingErrors(PyMongoTestCase):
def test_invalid_host(self):
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb is not",
MongoClient,
self.simple_client,
"mongodb+srv://mongodb",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb.com is not",
MongoClient,
self.simple_client,
"mongodb+srv://mongodb.com",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
MongoClient,
self.simple_client,
"mongodb+srv://127.0.0.1",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
MongoClient,
self.simple_client,
"mongodb+srv://[::1]",
)
class TestCaseInsensitive(IntegrationTest):
def test_connect_case_insensitive(self):
client = MongoClient("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
self.addCleanup(client.close)
self.assertGreater(len(client.topology_description.server_descriptions()), 1)

View File

@ -31,7 +31,7 @@ import warnings
from test import IntegrationTest, PyMongoTestCase, client_context
from test.test_bulk import BulkTestBase
from threading import Thread
from typing import Any, Dict, Mapping
from typing import Any, Dict, Mapping, Optional
import pytest
@ -53,6 +53,7 @@ from test.helpers import (
KMIP_CREDS,
LOCAL_MASTER_KEY,
)
from test.test_bulk import BulkTestBase
from test.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
@ -61,7 +62,6 @@ from test.utils import (
TopologyEventListener,
camel_to_snake_args,
is_greenthread_patched,
rs_or_single_client,
wait_until,
)
from test.utils_spec_runner import SpecRunner
@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
@unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed")
def test_crypt_shared(self):
# Test that we can pick up crypt_shared lib automatically
client = MongoClient(
self.simple_client(
auto_encryption_opts=AutoEncryptionOpts(
KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True
),
connect=False,
)
self.addCleanup(client.close)
@unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed")
def test_init_requires_pymongocrypt(self):
@ -196,19 +195,16 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
class TestClientOptions(PyMongoTestCase):
def test_default(self):
client = MongoClient(connect=False)
self.addCleanup(client.close)
client = self.simple_client(connect=False)
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
client = MongoClient(auto_encryption_opts=None, connect=False)
self.addCleanup(client.close)
client = self.simple_client(auto_encryption_opts=None, connect=False)
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def test_kwargs(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = MongoClient(auto_encryption_opts=opts, connect=False)
self.addCleanup(client.close)
client = self.simple_client(auto_encryption_opts=opts, connect=False)
self.assertEqual(get_client_opts(client).auto_encryption_opts, opts)
@ -229,6 +225,34 @@ class EncryptionIntegrationTest(IntegrationTest):
self.assertIsInstance(val, Binary)
self.assertEqual(val.subtype, UUID_SUBTYPE)
def create_client_encryption(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: MongoClient,
codec_options: CodecOptions,
kms_tls_options: Optional[Mapping[str, Any]] = None,
):
client_encryption = ClientEncryption(
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
)
self.addCleanup(client_encryption.close)
return client_encryption
@classmethod
def unmanaged_create_client_encryption(
cls,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: MongoClient,
codec_options: CodecOptions,
kms_tls_options: Optional[Mapping[str, Any]] = None,
):
client_encryption = ClientEncryption(
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
)
return client_encryption
# Location of JSON test files.
if _IS_SYNC:
@ -260,8 +284,7 @@ def bson_data(*paths):
class TestClientSimple(EncryptionIntegrationTest):
def _test_auto_encrypt(self, opts):
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
# Create the encrypted field's data key.
key_vault = create_key_vault(
@ -342,8 +365,7 @@ class TestClientSimple(EncryptionIntegrationTest):
def test_use_after_close(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
client.admin.command("ping")
client.close()
@ -358,10 +380,10 @@ class TestClientSimple(EncryptionIntegrationTest):
is_greenthread_patched(),
"gevent and eventlet do not support POSIX-style forking.",
)
@client_context.require_sync
def test_fork(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
def target():
with warnings.catch_warnings():
@ -375,8 +397,7 @@ class TestClientSimple(EncryptionIntegrationTest):
class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
def test_upsert_uuid_standard_encrypt(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
encrypted_coll = client.pymongo_test.test
@ -416,8 +437,7 @@ class TestClientMaxWireVersion(IntegrationTest):
@client_context.require_version_max(4, 0, 99)
def test_raise_max_wire_version_error(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
msg = "Auto-encryption requires a minimum MongoDB version of 4.2"
with self.assertRaisesRegex(ConfigurationError, msg):
client.test.test.insert_one({})
@ -430,8 +450,7 @@ class TestClientMaxWireVersion(IntegrationTest):
def test_raise_unsupported_error(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client.close)
client = self.rs_or_single_client(auto_encryption_opts=opts)
msg = "find_raw_batches does not support auto encryption"
with self.assertRaisesRegex(InvalidOperation, msg):
client.test.test.find_raw_batches({})
@ -450,10 +469,9 @@ class TestClientMaxWireVersion(IntegrationTest):
class TestExplicitSimple(EncryptionIntegrationTest):
def test_encrypt_decrypt(self):
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
)
self.addCleanup(client_encryption.close)
# Use standard UUID representation.
key_vault = client_context.client.keyvault.get_collection("datakeys", codec_options=OPTS)
self.addCleanup(key_vault.drop)
@ -493,10 +511,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
self.assertEqual(decrypted_ssn, doc["ssn"])
def test_validation(self):
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
)
self.addCleanup(client_encryption.close)
msg = "value to decrypt must be a bson.binary.Binary with subtype 6"
with self.assertRaisesRegex(TypeError, msg):
@ -510,10 +527,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
client_encryption.encrypt("str", algo, key_id=Binary(b"123"))
def test_bson_errors(self):
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
)
self.addCleanup(client_encryption.close)
# Attempt to encrypt an unencodable object.
unencodable_value = object()
@ -526,7 +542,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
def test_codec_options(self):
with self.assertRaisesRegex(TypeError, "codec_options must be"):
ClientEncryption(
self.create_client_encryption(
KMS_PROVIDERS,
"keyvault.datakeys",
client_context.client,
@ -534,10 +550,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
)
opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY)
client_encryption_legacy = ClientEncryption(
client_encryption_legacy = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts
)
self.addCleanup(client_encryption_legacy.close)
# Create the encrypted field's data key.
key_id = client_encryption_legacy.create_data_key("local")
@ -552,10 +567,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
# Encrypt the same UUID with STANDARD codec options.
opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts
)
self.addCleanup(client_encryption.close)
encrypted_standard = client_encryption.encrypt(
value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
)
@ -571,7 +585,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
self.assertNotEqual(client_encryption.decrypt(encrypted_legacy), value)
def test_close(self):
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
)
client_encryption.close()
@ -587,7 +601,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
client_encryption.decrypt(Binary(b"", 6))
def test_with_statement(self):
with ClientEncryption(
with self.create_client_encryption(
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
) as client_encryption:
pass
@ -807,7 +821,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
def _setup_class(cls):
super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener])
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.client.db.coll.drop()
cls.vault = create_key_vault(cls.client.keyvault.datakeys)
@ -829,10 +843,10 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
)
cls.client_encrypted = rs_or_single_client(
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
cls.client_encryption = ClientEncryption(
cls.client_encryption = cls.unmanaged_create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
)
@ -919,8 +933,7 @@ class TestExternalKeyVault(EncryptionIntegrationTest):
# Configure the encrypted field via the local schema_map option.
schemas = {"db.coll": json_data("external", "external-schema.json")}
if with_external_key_vault:
key_vault_client = rs_or_single_client(username="fake-user", password="fake-pwd")
self.addCleanup(key_vault_client.close)
key_vault_client = self.rs_or_single_client(username="fake-user", password="fake-pwd")
else:
key_vault_client = client_context.client
opts = AutoEncryptionOpts(
@ -930,15 +943,13 @@ class TestExternalKeyVault(EncryptionIntegrationTest):
key_vault_client=key_vault_client,
)
client_encrypted = rs_or_single_client(
client_encrypted = self.rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
self.addCleanup(client_encrypted.close)
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS
)
self.addCleanup(client_encryption.close)
if with_external_key_vault:
# Authentication error.
@ -984,10 +995,9 @@ class TestViews(EncryptionIntegrationTest):
self.addCleanup(self.client.db.view.drop)
opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys")
client_encrypted = rs_or_single_client(
client_encrypted = self.rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
self.addCleanup(client_encrypted.close)
with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"):
client_encrypted.db.view.insert_one({})
@ -1044,17 +1054,15 @@ class TestCorpus(EncryptionIntegrationTest):
)
self.addCleanup(vault.drop)
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client_encrypted.close)
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
self.kms_providers(),
"keyvault.datakeys",
client_context.client,
OPTS,
kms_tls_options=KMS_TLS_OPTS,
)
self.addCleanup(client_encryption.close)
corpus = self.fix_up_curpus(json_data("corpus", "corpus.json"))
corpus_copied: SON = SON()
@ -1197,7 +1205,7 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener()
cls.client_encrypted = rs_or_single_client(
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener]
)
cls.coll_encrypted = cls.client_encrypted.db.coll
@ -1285,7 +1293,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
"gcp": GCP_CREDS,
"kmip": KMIP_CREDS,
}
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers=kms_providers,
key_vault_namespace="keyvault.datakeys",
key_vault_client=client_context.client,
@ -1297,7 +1305,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443"
kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443"
kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698"
self.client_encryption_invalid = ClientEncryption(
self.client_encryption_invalid = self.create_client_encryption(
kms_providers=kms_providers_invalid,
key_vault_namespace="keyvault.datakeys",
key_vault_client=client_context.client,
@ -1476,7 +1484,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
self.client_encryption.create_data_key("kmip", key)
class AzureGCPEncryptionTestMixin:
class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
DEK = None
KMS_PROVIDER_MAP = None
KEYVAULT_DB = "keyvault"
@ -1488,7 +1496,7 @@ class AzureGCPEncryptionTestMixin:
create_key_vault(keyvault, self.DEK)
def _test_explicit(self, expectation):
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
client_context.client,
@ -1517,7 +1525,7 @@ class AzureGCPEncryptionTestMixin:
)
insert_listener = AllowListEventListener("insert")
client = rs_or_single_client(
client = self.rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addCleanup(client.close)
@ -1596,19 +1604,17 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests
class TestDeadlockProse(EncryptionIntegrationTest):
def setUp(self):
self.client_test = rs_or_single_client(
self.client_test = self.rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
self.addCleanup(self.client_test.close)
self.client_keyvault_listener = OvertCommandListener()
self.client_keyvault = rs_or_single_client(
self.client_keyvault = self.rs_or_single_client(
maxPoolSize=1,
readConcernLevel="majority",
w="majority",
event_listeners=[self.client_keyvault_listener],
)
self.addCleanup(self.client_keyvault.close)
self.client_test.keyvault.datakeys.drop()
self.client_test.db.coll.drop()
@ -1619,7 +1625,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
codec_options=OPTS,
)
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=self.client_test,
@ -1635,7 +1641,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
def _run_test(self, max_pool_size, auto_encryption_opts):
client_encrypted = rs_or_single_client(
client_encrypted = self.rs_or_single_client(
readConcernLevel="majority",
w="majority",
maxPoolSize=max_pool_size,
@ -1653,8 +1659,6 @@ class TestDeadlockProse(EncryptionIntegrationTest):
result = client_encrypted.db.coll.find_one({"_id": 0})
self.assertEqual(result, {"_id": 0, "encrypted": "string0"})
self.addCleanup(client_encrypted.close)
def test_case_1(self):
self._run_test(
max_pool_size=1,
@ -1830,7 +1834,7 @@ class TestDecryptProse(EncryptionIntegrationTest):
create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
)
keyID = self.client_encryption.create_data_key("local")
@ -1845,10 +1849,9 @@ class TestDecryptProse(EncryptionIntegrationTest):
key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map
)
self.listener = AllowListEventListener("aggregate")
self.encrypted_client = rs_or_single_client(
self.encrypted_client = self.rs_or_single_client(
auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener]
)
self.addCleanup(self.encrypted_client.close)
def test_01_command_error(self):
with self.fail_point(
@ -1925,8 +1928,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
"--port=27027",
],
)
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client_encrypted.close)
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
with self.assertRaisesRegex(EncryptionError, "Timeout"):
client_encrypted.db.coll.insert_one({"encrypted": "test"})
@ -1940,11 +1942,12 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
"--port=27027",
],
)
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client_encrypted.close)
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
client_encrypted.db.coll.insert_one({"unencrypted": "test"})
# Validate that mongocryptd was not spawned:
mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500")
mongocryptd_client = self.simple_client(
"mongodb://localhost:27027/?serverSelectionTimeoutMS=500"
)
with self.assertRaises(ServerSelectionTimeoutError):
mongocryptd_client.admin.command("ping")
@ -1966,15 +1969,13 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
],
crypt_shared_lib_required=True,
)
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client_encrypted.close)
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
client_encrypted.db.coll.drop()
client_encrypted.db.coll.insert_one({"encrypted": "test"})
self.assertEncrypted((client_context.client.db.coll.find_one({}))["encrypted"])
no_mongocryptd_client = MongoClient(
no_mongocryptd_client = self.simple_client(
host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000"
)
self.addCleanup(no_mongocryptd_client.close)
with self.assertRaises(ServerSelectionTimeoutError):
no_mongocryptd_client.db.command("ping")
@ -2008,8 +2009,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
mongocryptd_uri="mongodb://localhost:47021",
crypt_shared_lib_required=False,
)
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(client_encrypted.close)
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
client_encrypted.db.coll.drop()
client_encrypted.db.coll.insert_one({"encrypted": "test"})
server.shutdown()
@ -2023,10 +2023,9 @@ class TestKmsTLSProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.patch_system_certs(CA_PEM)
self.client_encrypted = ClientEncryption(
self.client_encrypted = self.create_client_encryption(
{"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS
)
self.addCleanup(self.client_encrypted.close)
def test_invalid_kms_certificate_expired(self):
key = {
@ -2071,36 +2070,32 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
"gcp": {"tlsCAFile": CA_PEM},
"kmip": {"tlsCAFile": CA_PEM},
}
self.client_encryption_no_client_cert = ClientEncryption(
self.client_encryption_no_client_cert = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
)
self.addCleanup(self.client_encryption_no_client_cert.close)
# 2, same providers as above but with tlsCertificateKeyFile.
kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only)
for p in kms_tls_opts:
kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM
self.client_encryption_with_tls = ClientEncryption(
self.client_encryption_with_tls = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
)
self.addCleanup(self.client_encryption_with_tls.close)
# 3, update endpoints to expired host.
providers: dict = copy.deepcopy(providers)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000"
providers["gcp"]["endpoint"] = "127.0.0.1:9000"
providers["kmip"]["endpoint"] = "127.0.0.1:9000"
self.client_encryption_expired = ClientEncryption(
self.client_encryption_expired = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
)
self.addCleanup(self.client_encryption_expired.close)
# 3, update endpoints to invalid host.
providers: dict = copy.deepcopy(providers)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001"
providers["gcp"]["endpoint"] = "127.0.0.1:9001"
providers["kmip"]["endpoint"] = "127.0.0.1:9001"
self.client_encryption_invalid_hostname = ClientEncryption(
self.client_encryption_invalid_hostname = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
)
self.addCleanup(self.client_encryption_invalid_hostname.close)
# Errors when client has no cert, some examples:
# [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623)
self.cert_error = (
@ -2138,7 +2133,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
"gcp:with_tls": with_cert,
"kmip:with_tls": with_cert,
}
self.client_encryption_with_names = ClientEncryption(
self.client_encryption_with_names = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4
)
@ -2220,10 +2215,9 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self):
providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}
options = {"aws": {"tlsDisableOCSPEndpointCheck": True}}
encryption = ClientEncryption(
encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
)
self.addCleanup(encryption.close)
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
if not hasattr(ctx, "check_ocsp_endpoint"):
raise self.skipTest("OCSP not enabled")
@ -2273,7 +2267,7 @@ class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
self.client = client_context.client
create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
)
self.def_key_id = self.client_encryption.create_data_key("local", key_alt_names=["def"])
@ -2311,17 +2305,15 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest):
key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addCleanup(key_vault.drop)
self.key_vault_client = self.client
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
)
self.addCleanup(self.client_encryption.close)
opts = AutoEncryptionOpts(
{"local": {"key": LOCAL_MASTER_KEY}},
key_vault.full_name,
bypass_query_analysis=True,
)
self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts)
self.addCleanup(self.encrypted_client.close)
self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts)
def test_01_insert_encrypted_indexed_and_find(self):
val = "encrypted indexed value"
@ -2444,14 +2436,13 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest):
self.client.keyvault.drop_collection("datakeys")
# Step 2. Create a ``ClientEncryption`` object named ``client_encryption1``
client_encryption1 = ClientEncryption(
client_encryption1 = self.create_client_encryption(
key_vault_client=self.client,
key_vault_namespace="keyvault.datakeys",
kms_providers=ALL_KMS_PROVIDERS,
kms_tls_options=KMS_TLS_OPTS,
codec_options=OPTS,
)
self.addCleanup(client_encryption1.close)
# Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``.
key_id = client_encryption1.create_data_key(
@ -2464,16 +2455,14 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest):
)
# Step 5. Create a ``ClientEncryption`` object named ``client_encryption2``
client2 = rs_or_single_client()
self.addCleanup(client2.close)
client_encryption2 = ClientEncryption(
client2 = self.rs_or_single_client()
client_encryption2 = self.create_client_encryption(
key_vault_client=client2,
key_vault_namespace="keyvault.datakeys",
kms_providers=ALL_KMS_PROVIDERS,
kms_tls_options=KMS_TLS_OPTS,
codec_options=OPTS,
)
self.addCleanup(client_encryption2.close)
# Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``.
rewrap_many_data_key_result = client_encryption2.rewrap_many_data_key(
@ -2508,7 +2497,7 @@ class TestOnDemandAWSCredentials(EncryptionIntegrationTest):
@unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set")
def test_01_failure(self):
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers={"aws": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=client_context.client,
@ -2519,7 +2508,7 @@ class TestOnDemandAWSCredentials(EncryptionIntegrationTest):
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
def test_02_success(self):
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
kms_providers={"aws": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=client_context.client,
@ -2539,8 +2528,7 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
# MongoClient to use in testing that handles auth/tls/etc,
# and cleanup.
def MongoClient(**kwargs):
c = rs_or_single_client(**kwargs)
self.addCleanup(c.close)
c = self.rs_or_single_client(**kwargs)
return c
# Drop data from prior test runs.
@ -2551,7 +2539,7 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
# Create two data keys.
key_vault_client = MongoClient()
client_encryption = ClientEncryption(
client_encryption = self.create_client_encryption(
kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions()
)
key1_id = client_encryption.create_data_key("local")
@ -2632,18 +2620,16 @@ class TestRangeQueryProse(EncryptionIntegrationTest):
key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addCleanup(key_vault.drop)
self.key_vault_client = self.client
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
)
self.addCleanup(self.client_encryption.close)
opts = AutoEncryptionOpts(
{"local": {"key": LOCAL_MASTER_KEY}},
key_vault.full_name,
bypass_query_analysis=True,
)
self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts)
self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts)
self.db = self.encrypted_client.db
self.addCleanup(self.encrypted_client.close)
def run_expression_find(
self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None
@ -2838,10 +2824,9 @@ class TestRangeQueryDefaultsProse(EncryptionIntegrationTest):
super().setUp()
self.client.drop_database(self.db)
self.key_vault_client = self.client
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS
)
self.addCleanup(self.client_encryption.close)
self.key_id = self.client_encryption.create_data_key("local")
opts = RangeOpts(min=0, max=1000)
self.payload_defaults = self.client_encryption.encrypt(
@ -2874,13 +2859,12 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest):
self.client.drop_database(self.db)
self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addCleanup(self.key_vault.drop)
self.client_encryption = ClientEncryption(
self.client_encryption = self.create_client_encryption(
{"local": {"key": LOCAL_MASTER_KEY}},
self.key_vault.full_name,
self.client,
OPTS,
)
self.addCleanup(self.client_encryption.close)
def test_01_simple_create(self):
coll, _ = self.client_encryption.create_encrypted_collection(
@ -3096,10 +3080,9 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
def setUp(self) -> None:
self.listener = OvertCommandListener()
self.mongocryptd_client = MongoClient(
self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
)
self.addCleanup(self.mongocryptd_client.close)
hello = self.mongocryptd_client.db.command("hello")
self.assertNotIn("logicalSessionTimeoutMinutes", hello)

View File

@ -22,7 +22,7 @@ import threading
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import rs_client, wait_until
from test.utils import wait_until
import pymongo
from pymongo.errors import ConnectionFailure, OperationFailure
@ -1128,7 +1128,7 @@ class TestTransactionExamples(IntegrationTest):
self.assertEqual(employee["status"], "Inactive")
def MongoClient(_):
return rs_client()
return self.rs_client()
uriString = None
@ -1220,7 +1220,7 @@ class TestVersionedApiExamples(IntegrationTest):
def test_versioned_api(self):
# Versioned API examples
def MongoClient(_, server_api):
return rs_client(server_api=server_api, connect=False)
return self.rs_client(server_api=server_api, connect=False)
uri = None
@ -1251,7 +1251,7 @@ class TestVersionedApiExamples(IntegrationTest):
):
self.skipTest("This test needs MongoDB 5.0.2 or newer")
client = rs_client(server_api=ServerApi("1", strict=True))
client = self.rs_client(server_api=ServerApi("1", strict=True))
client.db.sales.drop()
# Start Versioned API Example 5

View File

@ -33,7 +33,7 @@ from pymongo.synchronous.database import Database
sys.path[0:0] = [""]
from test.utils import EventListener, rs_or_single_client
from test.utils import EventListener
from bson.objectid import ObjectId
from gridfs.errors import NoFile
@ -790,7 +790,7 @@ Bye"""
outfile.readchunk()
def test_grid_in_lazy_connect(self):
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10)
fs = client.db.fs
infile = GridIn(fs, file_id=-1, chunk_size=1)
with self.assertRaises(ServerSelectionTimeoutError):
@ -801,7 +801,7 @@ Bye"""
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
GridIn((rs_or_single_client(w=0)).pymongo_test.fs)
GridIn((self.rs_or_single_client(w=0)).pymongo_test.fs)
def test_survive_cursor_not_found(self):
# By default the find command returns 101 documents in the first batch.
@ -809,7 +809,7 @@ Bye"""
chunk_size = 1024
data = b"d" * (102 * chunk_size)
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
db = client.pymongo_test
with GridIn(db.fs, chunk_size=chunk_size) as infile:
infile.write(data)

View File

@ -26,7 +26,7 @@ from unittest.mock import patch
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import joinall, one, rs_client, rs_or_single_client, single_client
from test.utils import joinall, one
import gridfs
from bson.binary import Binary
@ -411,7 +411,7 @@ class TestGridfs(IntegrationTest):
self.assertTrue(iterate_file(f))
def test_gridfs_lazy_connect(self):
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10)
db = client.db
gfs = gridfs.GridFS(db)
self.assertRaises(ServerSelectionTimeoutError, gfs.list)
@ -492,7 +492,7 @@ class TestGridfs(IntegrationTest):
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
gridfs.GridFS(rs_or_single_client(w=0).pymongo_test)
gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test)
def test_md5(self):
gin = self.fs.new_file()
@ -519,7 +519,7 @@ class TestGridfsReplicaSet(IntegrationTest):
client_context.client.drop_database("gfsreplica")
def test_gridfs_replica_set(self):
rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
fs = gridfs.GridFS(rsc.gfsreplica, "gfsreplicatest")
@ -532,7 +532,7 @@ class TestGridfsReplicaSet(IntegrationTest):
def test_gridfs_secondary(self):
secondary_host, secondary_port = one(self.client.secondaries)
secondary_connection = single_client(
secondary_connection = self.single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
)
@ -547,7 +547,7 @@ class TestGridfsReplicaSet(IntegrationTest):
# Should detect it's connected to secondary and not attempt to
# create index.
secondary_host, secondary_port = one(self.client.secondaries)
client = single_client(
client = self.single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
)

View File

@ -27,7 +27,7 @@ from unittest.mock import patch
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import joinall, one, rs_client, rs_or_single_client, single_client
from test.utils import joinall, one
import gridfs
from bson.binary import Binary
@ -345,7 +345,7 @@ class TestGridfs(IntegrationTest):
self.assertTrue(iterate_file(fstr))
def test_gridfs_lazy_connect(self):
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=0)
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0)
cdb = client.db
gfs = gridfs.GridFSBucket(cdb)
self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0)
@ -391,7 +391,7 @@ class TestGridfs(IntegrationTest):
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test)
gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test)
def test_rename(self):
_id = self.fs.upload_from_stream("first_name", b"testing")
@ -489,7 +489,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
client_context.client.drop_database("gfsbucketreplica")
def test_gridfs_replica_set(self):
rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest")
oid = gfs.upload_from_stream("test_filename", b"foo")
@ -498,7 +498,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
def test_gridfs_secondary(self):
secondary_host, secondary_port = one(self.client.secondaries)
secondary_connection = single_client(
secondary_connection = self.single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
)
@ -513,7 +513,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
# Should detect it's connected to secondary and not attempt to
# create index.
secondary_host, secondary_port = one(self.client.secondaries)
client = single_client(
client = self.single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
)

View File

@ -20,7 +20,7 @@ import sys
sys.path[0:0] = [""]
from test import IntegrationTest, client_knobs, unittest
from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until
from test.utils import HeartbeatEventListener, MockPool, wait_until
from pymongo.errors import ConnectionFailure
from pymongo.hello import Hello, HelloCompat
@ -40,7 +40,7 @@ class TestHeartbeatMonitoring(IntegrationTest):
raise responses[1]
return Hello(responses[1]), 99
m = single_client(
m = self.single_client(
h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool
)

View File

@ -26,7 +26,7 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes
from test.utils import ExceptionCatchingThread, get_pool, rs_client, wait_until
from test.utils import ExceptionCatchingThread, get_pool, wait_until
pytestmark = pytest.mark.load_balancer
@ -54,7 +54,7 @@ class TestLB(IntegrationTest):
@client_context.require_load_balancer
def test_unpin_committed_transaction(self):
client = rs_client()
client = self.rs_client()
self.addCleanup(client.close)
pool = get_pool(client)
coll = client[self.db.name].test
@ -85,7 +85,7 @@ class TestLB(IntegrationTest):
self._test_no_gc_deadlock(create_resource)
def _test_no_gc_deadlock(self, create_resource):
client = rs_client()
client = self.rs_client()
self.addCleanup(client.close)
pool = get_pool(client)
coll = client[self.db.name].test
@ -124,7 +124,7 @@ class TestLB(IntegrationTest):
@client_context.require_transactions
def test_session_gc(self):
client = rs_client()
client = self.rs_client()
self.addCleanup(client.close)
pool = get_pool(client)
session = client.start_session()

View File

@ -15,7 +15,6 @@ from __future__ import annotations
import os
from test import IntegrationTest, unittest
from test.utils import single_client
from unittest.mock import patch
from bson import json_util
@ -85,7 +84,7 @@ class TestLogger(IntegrationTest):
self.assertEqual(last_3_bytes, str_to_repeat)
def test_logging_without_listeners(self):
c = single_client()
c = self.single_client()
self.assertEqual(len(c._event_listeners.event_listeners()), 0)
with self.assertLogs("pymongo.connection", level="DEBUG") as cm:
c.db.test.insert_one({"x": "1"})

View File

@ -20,15 +20,14 @@ import sys
import time
import warnings
from pymongo import MongoClient
from pymongo.operations import _Op
sys.path[0:0] = [""]
from test import client_context, unittest
from test.utils import rs_or_single_client
from test import PyMongoTestCase, client_context, unittest
from test.utils_selection_tests import create_selection_tests
from pymongo import MongoClient
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import writable_server_selector
@ -40,54 +39,58 @@ class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
pass
class TestMaxStaleness(unittest.TestCase):
class TestMaxStaleness(PyMongoTestCase):
def test_max_staleness(self):
client = MongoClient()
client = self.simple_client()
self.assertEqual(-1, client.read_preference.max_staleness)
client = MongoClient("mongodb://a/?readPreference=secondary")
client = self.simple_client("mongodb://a/?readPreference=secondary")
self.assertEqual(-1, client.read_preference.max_staleness)
# These tests are specified in max-staleness-tests.rst.
with self.assertRaises(ConfigurationError):
# Default read pref "primary" can't be used with max staleness.
MongoClient("mongodb://a/?maxStalenessSeconds=120")
self.simple_client("mongodb://a/?maxStalenessSeconds=120")
with self.assertRaises(ConfigurationError):
# Read pref "primary" can't be used with max staleness.
MongoClient("mongodb://a/?readPreference=primary&maxStalenessSeconds=120")
self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120")
client = MongoClient("mongodb://host/?maxStalenessSeconds=-1")
client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1")
self.assertEqual(-1, client.read_preference.max_staleness)
client = MongoClient("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1")
client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1")
self.assertEqual(-1, client.read_preference.max_staleness)
client = MongoClient("mongodb://host/?readPreference=secondary&maxStalenessSeconds=120")
client = self.simple_client(
"mongodb://host/?readPreference=secondary&maxStalenessSeconds=120"
)
self.assertEqual(120, client.read_preference.max_staleness)
client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1")
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1")
self.assertEqual(1, client.read_preference.max_staleness)
client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1")
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1")
self.assertEqual(-1, client.read_preference.max_staleness)
client = MongoClient(maxStalenessSeconds=-1, readPreference="nearest")
client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest")
self.assertEqual(-1, client.read_preference.max_staleness)
with self.assertRaises(TypeError):
# Prohibit None.
MongoClient(maxStalenessSeconds=None, readPreference="nearest")
self.simple_client(maxStalenessSeconds=None, readPreference="nearest")
def test_max_staleness_float(self):
with self.assertRaises(TypeError) as ctx:
rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
self.rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
self.assertIn("must be an integer", str(ctx.exception))
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter("always")
client = MongoClient("mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest")
client = self.simple_client(
"mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest"
)
# Option was ignored.
self.assertEqual(-1, client.read_preference.max_staleness)
@ -96,13 +99,15 @@ class TestMaxStaleness(unittest.TestCase):
def test_max_staleness_zero(self):
# Zero is too small.
with self.assertRaises(ValueError) as ctx:
rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
self.rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
self.assertIn("must be a positive integer", str(ctx.exception))
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter("always")
client = MongoClient("mongodb://host/?maxStalenessSeconds=0&readPreference=nearest")
client = self.simple_client(
"mongodb://host/?maxStalenessSeconds=0&readPreference=nearest"
)
# Option was ignored.
self.assertEqual(-1, client.read_preference.max_staleness)
@ -111,7 +116,7 @@ class TestMaxStaleness(unittest.TestCase):
@client_context.require_replica_set
def test_last_write_date(self):
# From max-staleness-tests.rst, "Parse lastWriteDate".
client = rs_or_single_client(heartbeatFrequencyMS=500)
client = self.rs_or_single_client(heartbeatFrequencyMS=500)
client.pymongo_test.test.insert_one({})
# Wait for the server description to be updated.
time.sleep(1)

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import gc
import subprocess
import sys
import warnings
from functools import partial
sys.path[0:0] = [""]
@ -25,7 +26,6 @@ sys.path[0:0] = [""]
from test import IntegrationTest, connected, unittest
from test.utils import (
ServerAndTopologyEventListener,
single_client,
wait_until,
)
@ -47,30 +47,31 @@ def get_executors(client):
return [e for e in executors if e is not None]
def create_client():
listener = ServerAndTopologyEventListener()
client = single_client(event_listeners=[listener])
connected(client)
return client
class TestMonitor(IntegrationTest):
def create_client(self):
listener = ServerAndTopologyEventListener()
client = self.unmanaged_single_client(event_listeners=[listener])
connected(client)
return client
def test_cleanup_executors_on_client_del(self):
client = create_client()
executors = get_executors(client)
self.assertEqual(len(executors), 4)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
client = self.create_client()
executors = get_executors(client)
self.assertEqual(len(executors), 4)
# Each executor stores a weakref to itself in _EXECUTORS.
executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors]
# Each executor stores a weakref to itself in _EXECUTORS.
executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors]
del executors
del client
del executors
del client
for ref, name in executor_refs:
wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5)
for ref, name in executor_refs:
wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5)
def test_cleanup_executors_on_client_close(self):
client = create_client()
client = self.create_client()
executors = get_executors(client)
self.assertEqual(len(executors), 4)

View File

@ -31,8 +31,6 @@ from test import (
)
from test.utils import (
EventListener,
rs_or_single_client,
single_client,
wait_until,
)
@ -57,7 +55,9 @@ class TestCommandMonitoring(IntegrationTest):
def _setup_class(cls):
super()._setup_class()
cls.listener = EventListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False)
cls.client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@classmethod
def _tearDown_class(cls):
@ -405,7 +405,7 @@ class TestCommandMonitoring(IntegrationTest):
@client_context.require_secondaries_count(1)
def test_not_primary_error(self):
address = next(iter(client_context.client.secondaries))
client = single_client(*address, event_listeners=[self.listener])
client = self.single_client(*address, event_listeners=[self.listener])
# Clear authentication command results from the listener.
client.admin.command("ping")
self.listener.reset()
@ -1144,7 +1144,7 @@ class TestGlobalListener(IntegrationTest):
# We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener)
cls.client = single_client()
cls.client = cls.unmanaged_single_client()
# Get one (authenticated) socket in the pool.
cls.client.pymongo_test.command("ping")

View File

@ -31,7 +31,7 @@ from pymongo.hello import HelloCompat
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import delay, get_pool, joinall, rs_or_single_client
from test.utils import delay, get_pool, joinall
from pymongo.socket_checker import SocketChecker
from pymongo.synchronous.pool import Pool, PoolOptions
@ -151,7 +151,7 @@ class _TestPoolingBase(IntegrationTest):
def setUp(self):
super().setUp()
self.c = rs_or_single_client()
self.c = self.rs_or_single_client()
db = self.c[DB]
db.unique.drop()
db.test.drop()
@ -378,7 +378,7 @@ class TestPooling(_TestPoolingBase):
socket_info.close_conn(None)
def test_maxConnecting(self):
client = rs_or_single_client()
client = self.rs_or_single_client()
self.addCleanup(client.close)
self.client.test.test.insert_one({})
self.addCleanup(self.client.test.test.delete_many, {})
@ -415,7 +415,7 @@ class TestPooling(_TestPoolingBase):
@client_context.require_failCommand_appName
def test_csot_timeout_message(self):
client = rs_or_single_client(appName="connectionTimeoutApp")
client = self.rs_or_single_client(appName="connectionTimeoutApp")
self.addCleanup(client.close)
# Mock an operation failing due to pymongo.timeout().
mock_connection_timeout = {
@ -440,7 +440,7 @@ class TestPooling(_TestPoolingBase):
@client_context.require_failCommand_appName
def test_socket_timeout_message(self):
client = rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp")
client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp")
self.addCleanup(client.close)
# Mock an operation failing due to socketTimeoutMS.
mock_connection_timeout = {
@ -479,7 +479,7 @@ class TestPooling(_TestPoolingBase):
},
}
client = rs_or_single_client(
client = self.rs_or_single_client(
connectTimeoutMS=500,
socketTimeoutMS=500,
appName="connectionTimeoutApp",
@ -502,7 +502,7 @@ class TestPooling(_TestPoolingBase):
class TestPoolMaxSize(_TestPoolingBase):
def test_max_pool_size(self):
max_pool_size = 4
c = rs_or_single_client(maxPoolSize=max_pool_size)
c = self.rs_or_single_client(maxPoolSize=max_pool_size)
self.addCleanup(c.close)
collection = c[DB].test
@ -538,7 +538,7 @@ class TestPoolMaxSize(_TestPoolingBase):
self.assertEqual(0, cx_pool.requests)
def test_max_pool_size_none(self):
c = rs_or_single_client(maxPoolSize=None)
c = self.rs_or_single_client(maxPoolSize=None)
self.addCleanup(c.close)
collection = c[DB].test
@ -570,7 +570,7 @@ class TestPoolMaxSize(_TestPoolingBase):
self.assertEqual(cx_pool.max_pool_size, float("inf"))
def test_max_pool_size_zero(self):
c = rs_or_single_client(maxPoolSize=0)
c = self.rs_or_single_client(maxPoolSize=0)
self.addCleanup(c.close)
pool = get_pool(c)
self.assertEqual(pool.max_pool_size, float("inf"))

View File

@ -21,7 +21,7 @@ import unittest
sys.path[0:0] = [""]
from test import IntegrationTest, client_context
from test.utils import OvertCommandListener, rs_or_single_client
from test.utils import OvertCommandListener
from bson.son import SON
from pymongo.errors import OperationFailure
@ -36,7 +36,7 @@ class TestReadConcern(IntegrationTest):
def setUpClass(cls):
super().setUpClass()
cls.listener = OvertCommandListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener])
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
client_context.client.pymongo_test.create_collection("coll")
@ -67,7 +67,7 @@ class TestReadConcern(IntegrationTest):
def test_read_concern_uri(self):
uri = f"mongodb://{client_context.pair}/?readConcernLevel=majority"
client = rs_or_single_client(uri, connect=False)
client = self.rs_or_single_client(uri, connect=False)
self.assertEqual(ReadConcern("majority"), client.read_concern)
def test_invalid_read_concern(self):

View File

@ -30,8 +30,6 @@ from test import IntegrationTest, SkipTest, client_context, connected, unittest
from test.utils import (
OvertCommandListener,
one,
rs_client,
single_client,
wait_until,
)
from test.version import Version
@ -58,7 +56,7 @@ from pymongo.write_concern import WriteConcern
class TestSelections(IntegrationTest):
@client_context.require_connection
def test_bool(self):
client = single_client()
client = self.single_client()
wait_until(lambda: client.address, "discover primary")
selection = Selection.from_topology_description(client._topology.description)
@ -128,7 +126,7 @@ class TestReadPreferencesBase(IntegrationTest):
return None
def assertReadsFrom(self, expected, **kwargs):
c = rs_client(**kwargs)
c = self.rs_client(**kwargs)
wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes")
used = self.read_from_which_kind(c)
@ -139,7 +137,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase):
def test_reads_from_secondary(self):
host, port = next(iter(self.client.secondaries))
# Direct connection to a secondary.
client = single_client(host, port)
client = self.single_client(host, port)
self.assertFalse(client.is_primary)
# Regardless of read preference, we should be able to do
@ -175,19 +173,21 @@ class TestReadPreferences(TestReadPreferencesBase):
ReadPreference.SECONDARY_PREFERRED,
ReadPreference.NEAREST,
):
self.assertEqual(mode, rs_client(read_preference=mode).read_preference)
self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference)
self.assertRaises(TypeError, rs_client, read_preference="foo")
self.assertRaises(TypeError, self.rs_client, read_preference="foo")
def test_tag_sets_validation(self):
S = Secondary(tag_sets=[{}])
self.assertEqual([{}], rs_client(read_preference=S).read_preference.tag_sets)
self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets)
S = Secondary(tag_sets=[{"k": "v"}])
self.assertEqual([{"k": "v"}], rs_client(read_preference=S).read_preference.tag_sets)
self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets)
S = Secondary(tag_sets=[{"k": "v"}, {}])
self.assertEqual([{"k": "v"}, {}], rs_client(read_preference=S).read_preference.tag_sets)
self.assertEqual(
[{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets
)
self.assertRaises(ValueError, Secondary, tag_sets=[])
@ -200,20 +200,22 @@ class TestReadPreferences(TestReadPreferencesBase):
def test_threshold_validation(self):
self.assertEqual(
17, rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms
17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms
)
self.assertEqual(
42, rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms
42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms
)
self.assertEqual(
666, rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms
666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms
)
self.assertEqual(0, rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms)
self.assertEqual(
0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms
)
self.assertRaises(ValueError, rs_client, localthresholdms=-1)
self.assertRaises(ValueError, self.rs_client, localthresholdms=-1)
def test_zero_latency(self):
ping_times: set = set()
@ -223,7 +225,7 @@ class TestReadPreferences(TestReadPreferencesBase):
for ping_time, host in zip(ping_times, self.client.nodes):
ServerDescription._host_to_round_trip_time[host] = ping_time
try:
client = connected(rs_client(readPreference="nearest", localThresholdMS=0))
client = connected(self.rs_client(readPreference="nearest", localThresholdMS=0))
wait_until(lambda: client.nodes == self.client.nodes, "discovered all nodes")
host = self.read_from_which_host(client)
for _ in range(5):
@ -236,7 +238,7 @@ class TestReadPreferences(TestReadPreferencesBase):
def test_primary_with_tags(self):
# Tags not allowed with PRIMARY
self.assertRaises(ConfigurationError, rs_client, tag_sets=[{"dc": "ny"}])
self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}])
def test_primary_preferred(self):
self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED)
@ -250,7 +252,9 @@ class TestReadPreferences(TestReadPreferencesBase):
def test_nearest(self):
# With high localThresholdMS, expect to read from any
# member
c = rs_client(read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds
c = self.rs_client(
read_preference=ReadPreference.NEAREST, localThresholdMS=10000
) # 10 seconds
data_members = {self.client.primary} | self.client.secondaries
@ -540,7 +544,7 @@ class TestMongosAndReadPreference(IntegrationTest):
if client_context.supports_secondary_read_pref:
cases["secondary"] = Secondary
listener = OvertCommandListener()
client = rs_client(event_listeners=[listener])
client = self.rs_client(event_listeners=[listener])
self.addCleanup(client.close)
client.admin.command("ping")
for _mode, cls in cases.items():
@ -667,13 +671,13 @@ class TestMongosAndReadPreference(IntegrationTest):
else:
self.fail("mongos accepted invalid staleness")
coll = single_client(
coll = self.single_client(
readPreference="secondaryPreferred", maxStalenessSeconds=120
).pymongo_test.test
# No error
coll.find_one()
coll = single_client(
coll = self.single_client(
readPreference="secondaryPreferred", maxStalenessSeconds=10
).pymongo_test.test
try:

View File

@ -24,12 +24,7 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes
from test.utils import (
EventListener,
disable_replication,
enable_replication,
rs_or_single_client,
)
from test.utils import EventListener
from pymongo import DESCENDING
from pymongo.errors import (
@ -51,7 +46,7 @@ class TestReadWriteConcernSpec(IntegrationTest):
def test_omit_default_read_write_concern(self):
listener = EventListener()
# Client with default readConcern and writeConcern
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
@ -104,7 +99,9 @@ class TestReadWriteConcernSpec(IntegrationTest):
def assertWriteOpsRaise(self, write_concern, expected_exception):
wc = write_concern.document
# Set socket timeout to avoid indefinite stalls
client = rs_or_single_client(w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000)
client = self.rs_or_single_client(
w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000
)
db = client.get_database("pymongo_test")
coll = db.test
@ -167,9 +164,9 @@ class TestReadWriteConcernSpec(IntegrationTest):
@client_context.require_test_commands
def test_raise_wtimeout(self):
self.addCleanup(client_context.client.drop_database, "pymongo_test")
self.addCleanup(enable_replication, client_context.client)
self.addCleanup(self.enable_replication, client_context.client)
# Disable replication to guarantee a wtimeout error.
disable_replication(client_context.client)
self.disable_replication(client_context.client)
self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError)
@client_context.require_failCommand_fail_point
@ -209,7 +206,7 @@ class TestReadWriteConcernSpec(IntegrationTest):
@client_context.require_version_min(4, 9)
def test_write_error_details_exposes_errinfo(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
db = client.errinfotest
self.addCleanup(client.drop_database, "errinfotest")

View File

@ -34,7 +34,6 @@ from test import (
from test.utils import (
CMAPListener,
OvertCommandListener,
rs_or_single_client,
set_fail_point,
)
@ -93,7 +92,9 @@ class TestPoolPausedError(IntegrationTest):
self.skipTest("Test is flakey on PyPy")
cmap_listener = CMAPListener()
cmd_listener = OvertCommandListener()
client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener])
client = self.rs_or_single_client(
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
)
self.addCleanup(client.close)
for _ in range(10):
cmap_listener.reset()
@ -163,13 +164,13 @@ class TestRetryableReads(IntegrationTest):
mongos_clients = []
for mongos in client_context.mongos_seeds().split(","):
client = rs_or_single_client(mongos)
client = self.rs_or_single_client(mongos)
set_fail_point(client, fail_command)
self.addCleanup(client.close)
mongos_clients.append(client)
listener = OvertCommandListener()
client = rs_or_single_client(
client = self.rs_or_single_client(
client_context.mongos_seeds(),
appName="retryableReadTest",
event_listeners=[listener],

View File

@ -28,7 +28,6 @@ from test.utils import (
DeprecationFilter,
EventListener,
OvertCommandListener,
rs_or_single_client,
set_fail_point,
)
from test.version import Version
@ -145,7 +144,7 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.client = rs_or_single_client(retryWrites=True)
cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test
@classmethod
@ -181,7 +180,9 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.listener = OvertCommandListener()
cls.client = rs_or_single_client(retryWrites=True, event_listeners=[cls.listener])
cls.client = cls.unmanaged_rs_or_single_client(
retryWrites=True, event_listeners=[cls.listener]
)
cls.db = cls.client.pymongo_test
@classmethod
@ -204,7 +205,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()
client = rs_or_single_client(retryWrites=False, event_listeners=[listener])
client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener])
self.addCleanup(client.close)
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
@ -297,7 +298,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
def test_server_selection_timeout_not_retried(self):
"""A ServerSelectionTimeoutError is not retried."""
listener = OvertCommandListener()
client = MongoClient(
client = self.simple_client(
"somedomainthatdoesntexist.org",
serverSelectionTimeoutMS=1,
retryWrites=True,
@ -317,7 +318,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
original error.
"""
listener = OvertCommandListener()
client = rs_or_single_client(retryWrites=True, event_listeners=[listener])
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
self.addCleanup(client.close)
topology = client._topology
select_server = topology.select_server
@ -443,13 +444,13 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
mongos_clients = []
for mongos in client_context.mongos_seeds().split(","):
client = rs_or_single_client(mongos)
client = self.rs_or_single_client(mongos)
set_fail_point(client, fail_command)
self.addCleanup(client.close)
mongos_clients.append(client)
listener = OvertCommandListener()
client = rs_or_single_client(
client = self.rs_or_single_client(
client_context.mongos_seeds(),
appName="retryableWriteTest",
event_listeners=[listener],
@ -492,7 +493,7 @@ class TestWriteConcernError(IntegrationTest):
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
def test_RetryableWriteError_error_label(self):
listener = OvertCommandListener()
client = rs_or_single_client(retryWrites=True, event_listeners=[listener])
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
self.addCleanup(client.close)
# Ensure collection exists.
@ -551,7 +552,9 @@ class TestPoolPausedError(IntegrationTest):
def test_pool_paused_error_is_retryable(self):
cmap_listener = CMAPListener()
cmd_listener = OvertCommandListener()
client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener])
client = self.rs_or_single_client(
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
)
self.addCleanup(client.close)
for _ in range(10):
cmap_listener.reset()
@ -613,7 +616,7 @@ class TestPoolPausedError(IntegrationTest):
self,
):
cmd_listener = InsertEventListener()
client = rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener])
client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener])
client.test.test.drop()
self.addCleanup(client.close)
cmd_listener.reset()
@ -650,7 +653,7 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
the first attempt fails before sending the command.
"""
listener = OvertCommandListener()
client = rs_or_single_client(retryWrites=True, event_listeners=[listener])
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
self.addCleanup(client.close)
topology = client._topology
select_server = topology.select_server

View File

@ -25,7 +25,6 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context, client_knobs, unittest
from test.utils import (
ServerAndTopologyEventListener,
rs_or_single_client,
server_name_to_type,
wait_until,
)
@ -279,7 +278,7 @@ class TestSdamMonitoring(IntegrationTest):
cls.knobs.enable()
cls.listener = ServerAndTopologyEventListener()
retry_writes = client_context.supports_transactions()
cls.test_client = rs_or_single_client(
cls.test_client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=retry_writes
)
cls.coll = cls.test_client[cls.client.db.name].test

View File

@ -33,7 +33,6 @@ from test import IntegrationTest, client_context, unittest
from test.utils import (
EventListener,
FunctionCallRecorder,
rs_or_single_client,
wait_until,
)
from test.utils_selection_tests import (
@ -76,7 +75,9 @@ class TestCustomServerSelectorFunction(IntegrationTest):
# Initialize client with appropriate listeners.
listener = EventListener()
client = rs_or_single_client(server_selector=custom_selector, event_listeners=[listener])
client = self.rs_or_single_client(
server_selector=custom_selector, event_listeners=[listener]
)
self.addCleanup(client.close)
coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll
self.addCleanup(client.drop_database, "testdb")
@ -117,7 +118,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
selector = FunctionCallRecorder(lambda x: x)
# Client setup.
mongo_client = rs_or_single_client(server_selector=selector)
mongo_client = self.rs_or_single_client(server_selector=selector)
test_collection = mongo_client.testdb.test_collection
self.addCleanup(mongo_client.close)
self.addCleanup(mongo_client.drop_database, "testdb")

View File

@ -22,7 +22,6 @@ from test.utils import (
OvertCommandListener,
SpecTestCreator,
get_pool,
rs_client,
wait_until,
)
from test.utils_selection_tests import create_topology
@ -134,7 +133,7 @@ class TestProse(IntegrationTest):
listener = OvertCommandListener()
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
# varying RTTs.
client = rs_client(
client = self.rs_client(
client_context.mongos_seeds(),
appName="loadBalancingTest",
event_listeners=[listener],

View File

@ -36,7 +36,6 @@ from test import (
from test.utils import (
EventListener,
ExceptionCatchingThread,
rs_or_single_client,
wait_until,
)
@ -88,7 +87,7 @@ class TestSession(IntegrationTest):
super()._setup_class()
# Create a second client so we can make sure clients cannot share
# sessions.
cls.client2 = rs_or_single_client()
cls.client2 = cls.unmanaged_rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
@ -103,7 +102,7 @@ class TestSession(IntegrationTest):
def setUp(self):
self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener()
self.client = rs_or_single_client(
self.client = self.rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener]
)
self.addCleanup(self.client.close)
@ -200,7 +199,7 @@ class TestSession(IntegrationTest):
failures = 0
for _ in range(5):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
cursor = client.db.test.find({})
ops: List[Tuple[Callable, List[Any]]] = [
(client.db.test.find_one, [{"_id": 1}]),
@ -283,7 +282,7 @@ class TestSession(IntegrationTest):
def test_end_sessions(self):
# Use a new client so that the tearDown hook does not error.
listener = SessionTestListener()
client = rs_or_single_client(event_listeners=[listener])
client = self.rs_or_single_client(event_listeners=[listener])
# Start many sessions.
sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)]
for s in sessions:
@ -787,8 +786,7 @@ class TestSession(IntegrationTest):
def test_unacknowledged_writes(self):
# Ensure the collection exists.
self.client.pymongo_test.test_unacked_writes.insert_one({})
client = rs_or_single_client(w=0, event_listeners=[self.listener])
self.addCleanup(client.close)
client = self.rs_or_single_client(w=0, event_listeners=[self.listener])
db = client.pymongo_test
coll = db.test_unacked_writes
ops: list = [
@ -836,7 +834,7 @@ class TestCausalConsistency(UnitTest):
@classmethod
def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener])
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):
@ -1137,8 +1135,7 @@ class TestClusterTime(IntegrationTest):
def test_cluster_time(self):
listener = SessionTestListener()
# Prevent heartbeats from updating $clusterTime between operations.
client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999)
self.addCleanup(client.close)
client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999)
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
collection.insert_many([{} for _ in range(10)])

View File

@ -21,7 +21,7 @@ from typing import Any
sys.path[0:0] = [""]
from test import client_knobs, unittest
from test import PyMongoTestCase, client_knobs, unittest
from test.utils import FunctionCallRecorder, wait_until
import pymongo
@ -86,7 +86,7 @@ class SrvPollingKnobs:
self.disable()
class TestSrvPolling(unittest.TestCase):
class TestSrvPolling(PyMongoTestCase):
BASE_SRV_RESPONSE = [
("localhost.test.build.10gen.cc", 27017),
("localhost.test.build.10gen.cc", 27018),
@ -167,7 +167,7 @@ class TestSrvPolling(unittest.TestCase):
# Patch timeouts to ensure short test running times.
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(self.CONNECTION_STRING)
client = self.simple_client(self.CONNECTION_STRING)
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
# Patch list of hosts returned by DNS query.
with SrvPollingKnobs(
@ -231,7 +231,7 @@ class TestSrvPolling(unittest.TestCase):
count_resolver_calls=True,
):
# Client uses unpatched method to get initial nodelist
client = MongoClient(self.CONNECTION_STRING)
client = self.simple_client(self.CONNECTION_STRING)
# Invalid DNS resolver response should not change nodelist.
self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
@ -264,8 +264,7 @@ class TestSrvPolling(unittest.TestCase):
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=0)
self.addCleanup(client.close)
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
self.assert_nodelist_change(response, client)
@ -279,8 +278,7 @@ class TestSrvPolling(unittest.TestCase):
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2)
self.addCleanup(client.close)
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
self.assert_nodelist_change(response, client)
@ -295,8 +293,7 @@ class TestSrvPolling(unittest.TestCase):
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2)
self.addCleanup(client.close)
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
final_topology = set(client.topology_description.server_descriptions())
@ -305,8 +302,7 @@ class TestSrvPolling(unittest.TestCase):
def test_does_not_flipflop(self):
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=1)
self.addCleanup(client.close)
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
old = set(client.topology_description.server_descriptions())
sleep(4 * WAIT_TIME)
new = set(client.topology_description.server_descriptions())
@ -323,7 +319,7 @@ class TestSrvPolling(unittest.TestCase):
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(
client = self.simple_client(
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
)
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
@ -340,7 +336,7 @@ class TestSrvPolling(unittest.TestCase):
min_srv_rescan_interval=WAIT_TIME,
nodelist_callback=resolver_response,
):
client = MongoClient(self.CONNECTION_STRING)
client = self.simple_client(self.CONNECTION_STRING)
self.assertRaises(
AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2
)

View File

@ -24,6 +24,7 @@ sys.path[0:0] = [""]
from test import (
HAVE_IPADDRESS,
IntegrationTest,
PyMongoTestCase,
SkipTest,
client_context,
connected,
@ -82,45 +83,45 @@ MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=cl
# use 'localhost' for the hostname of all hosts.
class TestClientSSL(unittest.TestCase):
class TestClientSSL(PyMongoTestCase):
@unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.")
def test_no_ssl_module(self):
# Explicit
self.assertRaises(ConfigurationError, MongoClient, ssl=True)
self.assertRaises(ConfigurationError, self.simple_client, ssl=True)
# Implied
self.assertRaises(ConfigurationError, MongoClient, tlsCertificateKeyFile=CLIENT_PEM)
self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM)
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
@ignore_deprecations
def test_config_ssl(self):
# Tests various ssl configurations
self.assertRaises(ValueError, MongoClient, ssl="foo")
self.assertRaises(ValueError, self.simple_client, ssl="foo")
self.assertRaises(
ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
)
self.assertRaises(TypeError, MongoClient, ssl=0)
self.assertRaises(TypeError, MongoClient, ssl=5.5)
self.assertRaises(TypeError, MongoClient, ssl=[])
self.assertRaises(TypeError, self.simple_client, ssl=0)
self.assertRaises(TypeError, self.simple_client, ssl=5.5)
self.assertRaises(TypeError, self.simple_client, ssl=[])
self.assertRaises(IOError, MongoClient, tlsCertificateKeyFile="NoSuchFile")
self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=True)
self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=[])
self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile")
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True)
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[])
# Test invalid combinations
self.assertRaises(
ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
)
self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCAFile=CA_PEM)
self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCRLFile=CRL_PEM)
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM)
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM)
self.assertRaises(
ConfigurationError, MongoClient, tls=False, tlsAllowInvalidCertificates=False
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False
)
self.assertRaises(
ConfigurationError, MongoClient, tls=False, tlsAllowInvalidHostnames=False
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False
)
self.assertRaises(
ConfigurationError, MongoClient, tls=False, tlsDisableOCSPEndpointCheck=False
ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False
)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
@ -174,7 +175,7 @@ class TestSSL(IntegrationTest):
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
self.assertRaises(
ConfigurationError,
MongoClient,
self.simple_client,
"localhost",
ssl=True,
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
@ -184,7 +185,7 @@ class TestSSL(IntegrationTest):
)
else:
connected(
MongoClient(
self.simple_client(
"localhost",
ssl=True,
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
@ -201,7 +202,7 @@ class TestSSL(IntegrationTest):
"&tlsCAFile=%s&serverSelectionTimeoutMS=5000"
)
connected(
MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
)
@client_context.require_tlsCertificateKeyFile
@ -215,7 +216,7 @@ class TestSSL(IntegrationTest):
#
# test that setting tlsCertificateKeyFile causes ssl to be set to True
client = MongoClient(
client = self.simple_client(
client_context.host,
client_context.port,
tlsAllowInvalidCertificates=True,
@ -223,7 +224,7 @@ class TestSSL(IntegrationTest):
)
response = client.admin.command(HelloCompat.LEGACY_CMD)
if "setName" in response:
client = MongoClient(
client = self.simple_client(
client_context.pair,
replicaSet=response["setName"],
w=len(response["hosts"]),
@ -242,7 +243,7 @@ class TestSSL(IntegrationTest):
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
#
client = MongoClient(
client = self.simple_client(
"localhost",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
@ -257,7 +258,7 @@ class TestSSL(IntegrationTest):
"Cannot validate hostname in the certificate"
)
client = MongoClient(
client = self.simple_client(
"localhost",
replicaSet=response["setName"],
w=len(response["hosts"]),
@ -270,7 +271,7 @@ class TestSSL(IntegrationTest):
self.assertClientWorks(client)
if HAVE_IPADDRESS:
client = MongoClient(
client = self.simple_client(
"127.0.0.1",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
@ -292,7 +293,7 @@ class TestSSL(IntegrationTest):
"mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates"
"=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false"
)
client = MongoClient(uri_fmt % (CLIENT_PEM, "true", CA_PEM))
client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM))
self.assertClientWorks(client)
@client_context.require_tlsCertificateKeyFile
@ -316,7 +317,7 @@ class TestSSL(IntegrationTest):
with self.assertRaises(ConnectionFailure):
connected(
MongoClient(
self.simple_client(
"server",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
@ -328,7 +329,7 @@ class TestSSL(IntegrationTest):
)
connected(
MongoClient(
self.simple_client(
"server",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
@ -343,7 +344,7 @@ class TestSSL(IntegrationTest):
if "setName" in response:
with self.assertRaises(ConnectionFailure):
connected(
MongoClient(
self.simple_client(
"server",
replicaSet=response["setName"],
ssl=True,
@ -356,7 +357,7 @@ class TestSSL(IntegrationTest):
)
connected(
MongoClient(
self.simple_client(
"server",
replicaSet=response["setName"],
ssl=True,
@ -375,7 +376,7 @@ class TestSSL(IntegrationTest):
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
self.assertRaises(
ConfigurationError,
MongoClient,
self.simple_client,
"localhost",
ssl=True,
tlsCAFile=CA_PEM,
@ -384,7 +385,7 @@ class TestSSL(IntegrationTest):
)
else:
connected(
MongoClient(
self.simple_client(
"localhost",
ssl=True,
tlsCAFile=CA_PEM,
@ -395,7 +396,7 @@ class TestSSL(IntegrationTest):
with self.assertRaises(ConnectionFailure):
connected(
MongoClient(
self.simple_client(
"localhost",
ssl=True,
tlsCAFile=CA_PEM,
@ -406,7 +407,7 @@ class TestSSL(IntegrationTest):
)
uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000"
connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore
connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore
uri_fmt = (
"mongodb://localhost/?ssl=true&tlsCRLFile=%s"
@ -414,7 +415,7 @@ class TestSSL(IntegrationTest):
)
with self.assertRaises(ConnectionFailure):
connected(
MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
)
@client_context.require_tlsCertificateKeyFile
@ -431,12 +432,14 @@ class TestSSL(IntegrationTest):
with self.assertRaises(ConnectionFailure):
# Server cert is verified but hostname matching fails
connected(
MongoClient("server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type]
self.simple_client(
"server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
) # type: ignore[arg-type]
)
# Server cert is verified. Disable hostname matching.
connected(
MongoClient(
self.simple_client(
"server",
ssl=True,
tlsAllowInvalidHostnames=True,
@ -447,12 +450,14 @@ class TestSSL(IntegrationTest):
# Server cert and hostname are verified.
connected(
MongoClient("localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type]
self.simple_client(
"localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
) # type: ignore[arg-type]
)
# Server cert and hostname are verified.
connected(
MongoClient(
self.simple_client(
"mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000",
**self.credentials, # type: ignore[arg-type]
)
@ -472,7 +477,7 @@ class TestSSL(IntegrationTest):
ssl_support.HAVE_WINCERTSTORE = False
try:
with self.assertRaises(ConfigurationError):
MongoClient("mongodb://localhost/?ssl=true")
self.simple_client("mongodb://localhost/?ssl=true")
finally:
ssl_support.HAVE_CERTIFI = have_certifi
ssl_support.HAVE_WINCERTSTORE = have_wincertstore
@ -536,7 +541,7 @@ class TestSSL(IntegrationTest):
],
)
noauth = MongoClient(
noauth = self.simple_client(
client_context.pair,
ssl=True,
tlsAllowInvalidCertificates=True,
@ -548,7 +553,7 @@ class TestSSL(IntegrationTest):
noauth.pymongo_test.test.find_one()
listener = EventListener()
auth = MongoClient(
auth = self.simple_client(
client_context.pair,
authMechanism="MONGODB-X509",
ssl=True,
@ -572,7 +577,7 @@ class TestSSL(IntegrationTest):
host,
port,
)
client = MongoClient(
client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
self.addCleanup(client.close)
@ -580,7 +585,7 @@ class TestSSL(IntegrationTest):
client.pymongo_test.test.find_one()
uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port)
client = MongoClient(
client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
self.addCleanup(client.close)
@ -593,7 +598,7 @@ class TestSSL(IntegrationTest):
port,
)
bad_client = MongoClient(
bad_client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
self.addCleanup(bad_client.close)
@ -601,7 +606,7 @@ class TestSSL(IntegrationTest):
with self.assertRaises(OperationFailure):
bad_client.pymongo_test.test.find_one()
bad_client = MongoClient(
bad_client = self.simple_client(
client_context.pair,
username="not the username",
authMechanism="MONGODB-X509",
@ -622,7 +627,7 @@ class TestSSL(IntegrationTest):
)
try:
connected(
MongoClient(
self.simple_client(
uri,
ssl=True,
tlsAllowInvalidCertificates=True,
@ -648,7 +653,7 @@ class TestSSL(IntegrationTest):
self.addCleanup(remove, temp_ca_bundle)
# Add the CA cert file to the bundle.
cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM)
with MongoClient(
with self.simple_client(
"localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle
) as client:
self.assertTrue(client.admin.command("ping"))

View File

@ -24,8 +24,6 @@ from test import IntegrationTest, client_context, unittest
from test.utils import (
HeartbeatEventListener,
ServerEventListener,
rs_or_single_client,
single_client,
wait_until,
)
@ -38,7 +36,7 @@ class TestStreamingProtocol(IntegrationTest):
def test_failCommand_streaming(self):
listener = ServerEventListener()
hb_listener = HeartbeatEventListener()
client = rs_or_single_client(
client = self.rs_or_single_client(
event_listeners=[listener, hb_listener],
heartbeatFrequencyMS=500,
appName="failingHeartbeatTest",
@ -107,7 +105,7 @@ class TestStreamingProtocol(IntegrationTest):
},
}
with self.fail_point(delay_hello):
client = rs_or_single_client(
client = self.rs_or_single_client(
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name
)
self.addCleanup(client.close)
@ -155,7 +153,7 @@ class TestStreamingProtocol(IntegrationTest):
}
with self.fail_point(fail_hello):
start = time.time()
client = single_client(
client = self.single_client(
appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000
)
self.addCleanup(client.close)
@ -180,7 +178,7 @@ class TestStreamingProtocol(IntegrationTest):
@client_context.require_failCommand_appName
def test_heartbeat_awaited_flag(self):
hb_listener = HeartbeatEventListener()
client = single_client(
client = self.single_client(
event_listeners=[hb_listener],
heartbeatFrequencyMS=500,
appName="heartbeatEventAwaitedFlag",

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import sys
from io import BytesIO
from test.utils_spec_runner import SpecRunner
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
@ -25,8 +26,6 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import (
OvertCommandListener,
rs_client,
single_client,
wait_until,
)
from typing import List
@ -59,7 +58,18 @@ _IS_SYNC = True
UNPIN_TEST_MAX_ATTEMPTS = 50
class TestTransactions(IntegrationTest):
class TransactionsBase(SpecRunner):
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
if (
"secondary" in self.id()
and not client_context.is_mongos
and not client_context.has_secondaries
):
raise unittest.SkipTest("No secondaries")
class TestTransactions(TransactionsBase):
RUN_ON_SERVERLESS = True
@client_context.require_transactions
@ -92,8 +102,7 @@ class TestTransactions(IntegrationTest):
@client_context.require_transactions
def test_transaction_write_concern_override(self):
"""Test txn overrides Client/Database/Collection write_concern."""
client = rs_client(w=0)
self.addCleanup(client.close)
client = self.rs_client(w=0)
db = client.test
coll = db.test
coll.insert_one({})
@ -146,12 +155,11 @@ class TestTransactions(IntegrationTest):
def test_unpin_for_next_transaction(self):
# Increase localThresholdMS and wait until both nodes are discovered
# to avoid false positives.
client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
coll.insert_one({})
self.addCleanup(client.close)
with client.start_session() as s:
# Session is pinned to Mongos.
with s.start_transaction():
@ -174,12 +182,11 @@ class TestTransactions(IntegrationTest):
def test_unpin_for_non_transaction_operation(self):
# Increase localThresholdMS and wait until both nodes are discovered
# to avoid false positives.
client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
coll.insert_one({})
self.addCleanup(client.close)
with client.start_session() as s:
# Session is pinned to Mongos.
with s.start_transaction():
@ -303,11 +310,10 @@ class TestTransactions(IntegrationTest):
# Start a transaction with a batch of operations that needs to be
# split.
listener = OvertCommandListener()
client = rs_client(event_listeners=[listener])
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
coll.delete_many({})
listener.reset()
self.addCleanup(client.close)
self.addCleanup(coll.drop)
large_str = "\0" * (1 * 1024 * 1024)
ops: List[InsertOne[RawBSONDocument]] = [
@ -332,8 +338,7 @@ class TestTransactions(IntegrationTest):
@client_context.require_transactions
def test_transaction_direct_connection(self):
client = single_client()
self.addCleanup(client.close)
client = self.single_client()
coll = client.pymongo_test.test
# Make sure the collection exists.
@ -389,14 +394,14 @@ class PatchSessionTimeout:
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout
class TestTransactionsConvenientAPI(IntegrationTest):
class TestTransactionsConvenientAPI(TransactionsBase):
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.mongos_clients = []
if client_context.supports_transactions():
for address in client_context.mongoses:
cls.mongos_clients.append(single_client("{}:{}".format(*address)))
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
@classmethod
def _tearDown_class(cls):
@ -446,8 +451,7 @@ class TestTransactionsConvenientAPI(IntegrationTest):
@client_context.require_transactions
def test_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = rs_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
def callback(session):
@ -475,8 +479,7 @@ class TestTransactionsConvenientAPI(IntegrationTest):
@client_context.require_transactions
def test_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = rs_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
def callback(session):
@ -508,8 +511,7 @@ class TestTransactionsConvenientAPI(IntegrationTest):
@client_context.require_transactions
def test_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = rs_client(event_listeners=[listener])
self.addCleanup(client.close)
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
def callback(session):

View File

@ -68,8 +68,7 @@ except ImportError:
sys.path[0:0] = [""]
from test import IntegrationTest, client_context
from test.utils import rs_or_single_client
from test import IntegrationTest, PyMongoTestCase, client_context
from bson import CodecOptions, decode, decode_all, decode_file_iter, decode_iter, encode
from bson.raw_bson import RawBSONDocument
@ -194,7 +193,7 @@ class TestPymongo(IntegrationTest):
value.items()
def test_default_document_type(self) -> None:
client = rs_or_single_client()
client = self.rs_or_single_client()
self.addCleanup(client.close)
coll = client.test.test
doc = {"my": "doc"}
@ -366,7 +365,7 @@ class TestDecode(unittest.TestCase):
doc["a"] = 2
class TestDocumentType(unittest.TestCase):
class TestDocumentType(PyMongoTestCase):
@only_type_check
def test_default(self) -> None:
client: MongoClient = MongoClient()
@ -480,7 +479,7 @@ class TestDocumentType(unittest.TestCase):
def test_typeddict_find_notrequired(self):
if NotRequired is None or ImplicitMovie is None:
raise unittest.SkipTest("Python 3.11+ is required to use NotRequired.")
client: MongoClient[ImplicitMovie] = rs_or_single_client()
client: MongoClient[ImplicitMovie] = self.rs_or_single_client()
coll = client.test.test
coll.insert_one(ImplicitMovie(name="THX-1138", year=1971))
out = coll.find_one({})

View File

@ -20,7 +20,7 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes
from test.utils import OvertCommandListener, rs_or_single_client
from test.utils import OvertCommandListener
from pymongo.server_api import ServerApi, ServerApiVersion
from pymongo.synchronous.mongo_client import MongoClient
@ -77,7 +77,7 @@ class TestServerApi(IntegrationTest):
@client_context.require_version_min(4, 7)
def test_command_options(self):
listener = OvertCommandListener()
client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
self.addCleanup(client.close)
coll = client.test.test
coll.insert_many([{} for _ in range(100)])
@ -90,7 +90,7 @@ class TestServerApi(IntegrationTest):
@client_context.require_transactions
def test_command_options_txn(self):
listener = OvertCommandListener()
client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
self.addCleanup(client.close)
coll = client.test.test
coll.insert_many([{} for _ in range(100)])

View File

@ -55,8 +55,6 @@ from test.utils import (
parse_collection_options,
parse_spec_options,
prepare_spec_arguments,
rs_or_single_client,
single_client,
snake_to_camel,
wait_until,
)
@ -574,7 +572,7 @@ class EntityMapUtil:
)
if uri:
kwargs["h"] = uri
client = rs_or_single_client(**kwargs)
client = self.test.rs_or_single_client(**kwargs)
self[spec["id"]] = client
self.test.addCleanup(client.close)
return
@ -1115,7 +1113,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
and not client_context.serverless
):
for address in client_context.mongoses:
cls.mongos_clients.append(single_client("{}:{}".format(*address)))
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(
@ -1646,7 +1644,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
)
)
client = single_client("{}:{}".format(*session._pinned_address))
client = self.single_client("{}:{}".format(*session._pinned_address))
self.addCleanup(client.close)
self.__set_fail_point(client=client, command_args=spec["failPoint"])

View File

@ -565,151 +565,6 @@ class SpecTestCreator:
setattr(self._test_class, new_test.__name__, new_test)
def _connection_string(h):
if h.startswith(("mongodb://", "mongodb+srv://")):
return h
return f"mongodb://{h!s}"
def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options: dict = client_context.default_client_options.copy()
if client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)
uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
return MongoClient(uri, port, **client_options)
async def _async_mongo_client(host, port, authenticate=True, directConnection=None, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
host = host or await async_client_context.host
port = port or await async_client_context.port
client_options: dict = async_client_context.default_client_options.copy()
if async_client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = async_client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)
uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = AsyncMongoClient(uri, port, **client_options)
if client._options.connect:
await client.aconnect()
return client
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Make a direct connection, and authenticate if necessary."""
return _mongo_client(h, p, directConnection=True, **kwargs)
def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set. Don't authenticate."""
return _mongo_client(h, p, authenticate=False, **kwargs)
def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return _mongo_client(h, p, **kwargs)
def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
"""
return _mongo_client(h, p, authenticate=False, **kwargs)
def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.
"""
return _mongo_client(h, p, **kwargs)
async def async_single_client_noauth(
h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return await _async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
async def async_single_client(
h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Make a direct connection, and authenticate if necessary."""
return await _async_mongo_client(h, p, directConnection=True, **kwargs)
async def async_rs_client_noauth(
h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Connect to the replica set. Don't authenticate."""
return await _async_mongo_client(h, p, authenticate=False, **kwargs)
async def async_rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return await _async_mongo_client(h, p, **kwargs)
async def async_rs_or_single_client_noauth(
h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
"""
return await _async_mongo_client(h, p, authenticate=False, **kwargs)
async def async_rs_or_single_client(
h: Any = None, p: Any = None, **kwargs: Any
) -> AsyncMongoClient[Any]:
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.
"""
return await _async_mongo_client(h, p, **kwargs)
def ensure_all_connected(client: MongoClient) -> None:
"""Ensure that the client's connection pool has socket connections to all
members of a replica set. Raises ConfigurationError when called with a
@ -1108,20 +963,6 @@ def is_greenthread_patched():
return gevent_monkey_patched() or eventlet_monkey_patched()
def disable_replication(client):
"""Disable replication on all secondaries."""
for host, port in client.secondaries:
secondary = single_client(host, port)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
def enable_replication(client):
"""Enable replication on all secondaries."""
for host, port in client.secondaries:
secondary = single_client(host, port)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
class ExceptionCatchingThread(threading.Thread):
"""A thread that stores any exception encountered from run()."""

View File

@ -29,7 +29,6 @@ from test.utils import (
camel_to_snake_args,
parse_spec_options,
prepare_spec_arguments,
rs_client,
)
from typing import List
@ -101,6 +100,8 @@ class SpecRunner(IntegrationTest):
@classmethod
def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
def setUp(self):
@ -524,7 +525,7 @@ class SpecRunner(IntegrationTest):
host = client_context.MULTI_MONGOS_LB_URI
elif client_context.is_mongos:
host = client_context.mongos_seeds()
client = rs_client(
client = self.rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client