Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
54fd7b6104
@ -2081,10 +2081,10 @@ axes:
|
||||
batchtime: 10080 # 7 days
|
||||
- id: rhel8
|
||||
display_name: "RHEL 8.x"
|
||||
run_on: rhel87-small
|
||||
run_on: rhel8.8-small
|
||||
batchtime: 10080 # 7 days
|
||||
- id: rhel92-fips
|
||||
display_name: "RHEL 9.2 FIPS"
|
||||
- id: rhel9-fips
|
||||
display_name: "RHEL 9 FIPS"
|
||||
run_on: rhel92-fips
|
||||
batchtime: 10080 # 7 days
|
||||
- id: ubuntu-22.04
|
||||
@ -2095,24 +2095,24 @@ axes:
|
||||
display_name: "Ubuntu 20.04"
|
||||
run_on: ubuntu2004-small
|
||||
batchtime: 10080 # 7 days
|
||||
- id: rhel83-zseries
|
||||
display_name: "RHEL 8.3 (zSeries)"
|
||||
run_on: rhel83-zseries-small
|
||||
- id: rhel8-zseries
|
||||
display_name: "RHEL 8 (zSeries)"
|
||||
run_on: rhel8-zseries-small
|
||||
batchtime: 10080 # 7 days
|
||||
variables:
|
||||
SKIP_HATCH: true
|
||||
- id: rhel81-power8
|
||||
display_name: "RHEL 8.1 (POWER8)"
|
||||
run_on: rhel81-power8-small
|
||||
- id: rhel8-power8
|
||||
display_name: "RHEL 8 (POWER8)"
|
||||
run_on: rhel8-power-small
|
||||
batchtime: 10080 # 7 days
|
||||
variables:
|
||||
SKIP_HATCH: true
|
||||
- id: rhel82-arm64
|
||||
display_name: "RHEL 8.2 (ARM64)"
|
||||
- id: rhel8-arm64
|
||||
display_name: "RHEL 8 (ARM64)"
|
||||
run_on: rhel82-arm64-small
|
||||
batchtime: 10080 # 7 days
|
||||
variables:
|
||||
- id: windows-64-vsMulti-small
|
||||
- id: windows
|
||||
display_name: "Windows 64"
|
||||
run_on: windows-64-vsMulti-small
|
||||
batchtime: 10080 # 7 days
|
||||
@ -2470,7 +2470,7 @@ buildvariants:
|
||||
- matrix_name: "tests-fips"
|
||||
matrix_spec:
|
||||
platform:
|
||||
- rhel92-fips
|
||||
- rhel9-fips
|
||||
auth: "auth"
|
||||
ssl: "ssl"
|
||||
display_name: "${platform} ${auth} ${ssl}"
|
||||
@ -2547,9 +2547,9 @@ buildvariants:
|
||||
- matrix_name: "test-different-cpu-architectures"
|
||||
matrix_spec:
|
||||
platform:
|
||||
- rhel83-zseries # Added in 5.0.8 (SERVER-44074)
|
||||
- rhel81-power8 # Added in 4.2.7 (SERVER-44072)
|
||||
- rhel82-arm64 # Added in 4.4.2 (SERVER-48282)
|
||||
- rhel8-zseries # Added in 5.0.8 (SERVER-44074)
|
||||
- rhel8-power8 # Added in 4.2.7 (SERVER-44072)
|
||||
- rhel8-arm64 # Added in 4.4.2 (SERVER-48282)
|
||||
auth-ssl: "*"
|
||||
display_name: "${platform} ${auth-ssl}"
|
||||
tasks:
|
||||
@ -2606,7 +2606,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "tests-pyopenssl-windows"
|
||||
matrix_spec:
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows: "*"
|
||||
auth: "auth"
|
||||
ssl: "ssl"
|
||||
@ -2698,7 +2698,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "tests-windows-python-version"
|
||||
matrix_spec:
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows: "*"
|
||||
auth-ssl: "*"
|
||||
display_name: "${platform} ${python-version-windows} ${auth-ssl}"
|
||||
@ -2706,7 +2706,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "tests-windows-python-version-32-bit"
|
||||
matrix_spec:
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows-32: "*"
|
||||
auth-ssl: "*"
|
||||
display_name: "${platform} ${python-version-windows-32} ${auth-ssl}"
|
||||
@ -2724,7 +2724,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "tests-windows-encryption"
|
||||
matrix_spec:
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows: "*"
|
||||
auth-ssl: "*"
|
||||
encryption: "*"
|
||||
@ -2733,7 +2733,7 @@ buildvariants:
|
||||
rules:
|
||||
- if:
|
||||
encryption: ["encryption", "encryption_crypt_shared"]
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows: "*"
|
||||
auth-ssl: "*"
|
||||
then:
|
||||
@ -2795,7 +2795,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "tests-windows-enterprise-auth"
|
||||
matrix_spec:
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows: "*"
|
||||
auth: "auth"
|
||||
display_name: "Enterprise ${auth} ${platform} ${python-version-windows}"
|
||||
@ -2907,7 +2907,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "ocsp-test-windows"
|
||||
matrix_spec:
|
||||
platform: windows-64-vsMulti-small
|
||||
platform: windows
|
||||
python-version-windows: ["3.8", "3.10"]
|
||||
mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"]
|
||||
auth: "noauth"
|
||||
@ -2932,7 +2932,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "oidc-auth-test"
|
||||
matrix_spec:
|
||||
platform: [ rhel8, macos, windows-64-vsMulti-small ]
|
||||
platform: [ rhel8, macos, windows ]
|
||||
display_name: "OIDC Auth ${platform}"
|
||||
tasks:
|
||||
- name: testoidc_task_group
|
||||
@ -2981,7 +2981,7 @@ buildvariants:
|
||||
|
||||
- matrix_name: "aws-auth-test-windows"
|
||||
matrix_spec:
|
||||
platform: [windows-64-vsMulti-small]
|
||||
platform: [windows]
|
||||
python-version-windows: "*"
|
||||
display_name: "MONGODB-AWS Auth ${platform} ${python-version-windows}"
|
||||
tasks:
|
||||
|
||||
@ -163,13 +163,10 @@ hatch run lint:build-manual
|
||||
|
||||
## Documentation
|
||||
|
||||
To contribute to the [API
|
||||
documentation](https://pymongo.readthedocs.io/en/stable/) just make your
|
||||
changes to the inline documentation of the appropriate [source
|
||||
code](https://github.com/mongodb/mongo-python-driver) or [rst
|
||||
file](https://github.com/mongodb/mongo-python-driver/tree/master/doc) in
|
||||
a branch and submit a [pull
|
||||
request](https://help.github.com/articles/using-pull-requests). You
|
||||
To contribute to the [API documentation](https://pymongo.readthedocs.io/en/stable/) just make your
|
||||
changes to the inline documentation of the appropriate [source code](https://github.com/mongodb/mongo-python-driver) or
|
||||
[rst file](https://github.com/mongodb/mongo-python-driver/tree/master/doc) in
|
||||
a branch and submit a [pull request](https://help.github.com/articles/using-pull-requests). You
|
||||
might also use the GitHub
|
||||
[Edit](https://github.com/blog/844-forking-with-the-edit-button) button.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
[](https://pypi.org/project/pymongo)
|
||||
[](https://pypi.org/project/pymongo)
|
||||
[](https://pepy.tech/project/pymongo)
|
||||
[](http://pymongo.readthedocs.io/en/stable/?badge=stable)
|
||||
[](http://pymongo.readthedocs.io/en/stable/api?badge=stable)
|
||||
|
||||
## About
|
||||
|
||||
|
||||
@ -306,7 +306,7 @@ static PyObject* datetime_from_millis(long long millis) {
|
||||
if (evalue) {
|
||||
PyObject* err_msg = PyObject_Str(evalue);
|
||||
if (err_msg) {
|
||||
PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes");
|
||||
PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes");
|
||||
if (appendage) {
|
||||
PyObject* msg = PyUnicode_Concat(err_msg, appendage);
|
||||
if (msg) {
|
||||
|
||||
@ -31,7 +31,7 @@ EPOCH_NAIVE = EPOCH_AWARE.replace(tzinfo=None)
|
||||
_DATETIME_ERROR_SUGGESTION = (
|
||||
"(Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO)"
|
||||
" or MongoClient(datetime_conversion='DATETIME_AUTO'))."
|
||||
" See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes"
|
||||
" See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -93,6 +93,10 @@ Unavoidable breaking changes
|
||||
- Since we are now using ``hatch`` as our build backend, we no longer have a usable ``setup.py`` file
|
||||
and require installation using ``pip``. Attempts to invoke the ``setup.py`` file will raise an exception.
|
||||
Additionally, ``pip`` >= 21.3 is now required for editable installs.
|
||||
- We no longer support the ``srv`` extra, since ``dnspython`` is included as a dependency in PyMongo 4.7+.
|
||||
Instead of ``pip install pymongo[srv]``, use ``pip install pymongo``.
|
||||
- We no longer support the ``tls`` extra, which was only valid for Python 2.
|
||||
Instead of ``pip install pymongo[tls]``, use ``pip install pymongo``.
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
PyMongo |release| Documentation
|
||||
===============================
|
||||
|
||||
.. note:: The PyMongo documentation has been migrated to the
|
||||
`MongoDB Documentation site <https://www.mongodb.com/docs/languages/python/pymongo-driver/current>`_.
|
||||
As of PyMongo 4.10, the ReadTheDocs site will contain the detailed changelog and API docs, while the
|
||||
rest of the documentation will only appear on the MongoDB Documentation site.
|
||||
|
||||
Overview
|
||||
--------
|
||||
**PyMongo** is a Python distribution containing tools for working with
|
||||
@ -95,8 +100,6 @@ pull request.
|
||||
Changes
|
||||
-------
|
||||
See the :doc:`changelog` for a full list of changes to PyMongo.
|
||||
For older versions of the documentation please see the
|
||||
`archive list <http://api.mongodb.org/python/>`_.
|
||||
|
||||
About This Documentation
|
||||
------------------------
|
||||
|
||||
@ -29,6 +29,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
@ -77,6 +78,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
self._address = address
|
||||
self._batch_size = batch_size
|
||||
self._max_await_time_ms = max_await_time_ms
|
||||
self._timeout = self._collection.database.client.options.timeout
|
||||
self._session = session
|
||||
self._explicit_session = explicit_session
|
||||
self._killed = self._id == 0
|
||||
@ -385,6 +387,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
@_csot.apply
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ from typing import (
|
||||
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
|
||||
from bson.code import Code
|
||||
from bson.son import SON
|
||||
from pymongo import helpers_shared
|
||||
from pymongo import _csot, helpers_shared
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import (
|
||||
@ -196,6 +196,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
self._explain = False
|
||||
self._comment = comment
|
||||
self._max_time_ms = max_time_ms
|
||||
self._timeout = self._collection.database.client.options.timeout
|
||||
self._max_await_time_ms: Optional[int] = None
|
||||
self._max: Optional[Union[dict[Any, Any], _Sort]] = max
|
||||
self._min: Optional[Union[dict[Any, Any], _Sort]] = min
|
||||
@ -1290,6 +1291,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
@_csot.apply
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
|
||||
@ -1193,7 +1193,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
),
|
||||
ResourceWarning,
|
||||
stacklevel=2,
|
||||
source=self,
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@ -227,8 +227,9 @@ class Topology:
|
||||
warnings.warn( # type: ignore[call-overload] # noqa: B028
|
||||
"AsyncMongoClient opened before fork. May not be entirely fork-safe, "
|
||||
"proceed with caution. See PyMongo's documentation for details: "
|
||||
"https://pymongo.readthedocs.io/en/stable/faq.html#"
|
||||
"is-pymongo-fork-safe",
|
||||
"https://www.mongodb.com/docs/languages/"
|
||||
"python/pymongo-driver/current/faq/"
|
||||
"#is-pymongo-fork-safe-",
|
||||
**kwargs,
|
||||
)
|
||||
async with self._lock:
|
||||
|
||||
@ -29,6 +29,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo import _csot
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.message import (
|
||||
@ -77,6 +78,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
self._address = address
|
||||
self._batch_size = batch_size
|
||||
self._max_await_time_ms = max_await_time_ms
|
||||
self._timeout = self._collection.database.client.options.timeout
|
||||
self._session = session
|
||||
self._explicit_session = explicit_session
|
||||
self._killed = self._id == 0
|
||||
@ -385,6 +387,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
@_csot.apply
|
||||
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ from typing import (
|
||||
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
|
||||
from bson.code import Code
|
||||
from bson.son import SON
|
||||
from pymongo import helpers_shared
|
||||
from pymongo import _csot, helpers_shared
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
@ -196,6 +196,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
self._explain = False
|
||||
self._comment = comment
|
||||
self._max_time_ms = max_time_ms
|
||||
self._timeout = self._collection.database.client.options.timeout
|
||||
self._max_await_time_ms: Optional[int] = None
|
||||
self._max: Optional[Union[dict[Any, Any], _Sort]] = max
|
||||
self._min: Optional[Union[dict[Any, Any], _Sort]] = min
|
||||
@ -1288,6 +1289,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
@_csot.apply
|
||||
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
|
||||
|
||||
|
||||
@ -1193,7 +1193,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
),
|
||||
ResourceWarning,
|
||||
stacklevel=2,
|
||||
source=self,
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@ -227,8 +227,9 @@ class Topology:
|
||||
warnings.warn( # type: ignore[call-overload] # noqa: B028
|
||||
"MongoClient opened before fork. May not be entirely fork-safe, "
|
||||
"proceed with caution. See PyMongo's documentation for details: "
|
||||
"https://pymongo.readthedocs.io/en/stable/faq.html#"
|
||||
"is-pymongo-fork-safe",
|
||||
"https://www.mongodb.com/docs/languages/"
|
||||
"python/pymongo-driver/current/faq/"
|
||||
"#is-pymongo-fork-safe-",
|
||||
**kwargs,
|
||||
)
|
||||
with self._lock:
|
||||
|
||||
@ -42,7 +42,7 @@ classifiers = [
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://www.mongodb.org"
|
||||
Documentation = "https://pymongo.readthedocs.io"
|
||||
Documentation = "https://www.mongodb.com/docs/languages/python/pymongo-driver/current/"
|
||||
Source = "https://github.com/mongodb/mongo-python-driver"
|
||||
Tracker = "https://jira.mongodb.org/projects/PYTHON/issues"
|
||||
|
||||
@ -96,9 +96,6 @@ filterwarnings = [
|
||||
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
|
||||
# https://github.com/dateutil/dateutil/issues/1314
|
||||
"module:datetime.datetime.utc:DeprecationWarning:dateutil",
|
||||
# TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731
|
||||
"ignore:Unclosed AsyncMongoClient*",
|
||||
"ignore:Unclosed MongoClient*",
|
||||
]
|
||||
markers = [
|
||||
"auth_aws: tests that rely on pymongo-auth-aws",
|
||||
|
||||
189
test/__init__.py
189
test/__init__.py
@ -16,8 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import gc
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -27,7 +25,6 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
@ -54,6 +51,8 @@ from test.helpers import (
|
||||
sanitize_reply,
|
||||
)
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -80,6 +79,12 @@ from pymongo.synchronous.mongo_client import MongoClient
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _connection_string(h):
|
||||
if h.startswith(("mongodb://", "mongodb+srv://")):
|
||||
return h
|
||||
return f"mongodb://{h!s}"
|
||||
|
||||
|
||||
class ClientContext:
|
||||
client: MongoClient
|
||||
|
||||
@ -230,6 +235,9 @@ class ClientContext:
|
||||
if not self._check_user_provided():
|
||||
_create_user(self.client.admin, db_user, db_pwd)
|
||||
|
||||
if self.client:
|
||||
self.client.close()
|
||||
|
||||
self.client = self._connect(
|
||||
host,
|
||||
port,
|
||||
@ -256,6 +264,8 @@ class ClientContext:
|
||||
if "setName" in hello:
|
||||
self.replica_set_name = str(hello["setName"])
|
||||
self.is_rs = True
|
||||
if self.client:
|
||||
self.client.close()
|
||||
if self.auth_enabled:
|
||||
# It doesn't matter which member we use as the seed here.
|
||||
self.client = pymongo.MongoClient(
|
||||
@ -318,6 +328,7 @@ class ClientContext:
|
||||
hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if hello.get("msg") == "isdbgrid":
|
||||
self.mongoses.append(next_address)
|
||||
mongos_client.close()
|
||||
|
||||
def init(self):
|
||||
with self.conn_lock:
|
||||
@ -537,12 +548,6 @@ class ClientContext:
|
||||
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
|
||||
)
|
||||
|
||||
def require_no_fips(self, func):
|
||||
"""Run a test only if the host does not have FIPS enabled."""
|
||||
return self._require(
|
||||
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
|
||||
)
|
||||
|
||||
def require_no_auth(self, func):
|
||||
"""Run a test only if the server is running without auth enabled."""
|
||||
return self._require(
|
||||
@ -930,6 +935,172 @@ class PyMongoTestCase(unittest.TestCase):
|
||||
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
|
||||
self.assertEqual(proc.exitcode, 0)
|
||||
|
||||
@classmethod
|
||||
def _unmanaged_async_mongo_client(
|
||||
cls, host, port, authenticate=True, directConnection=None, **kwargs
|
||||
):
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or client_context.host
|
||||
port = port or client_context.port
|
||||
client_options: dict = client_context.default_client_options.copy()
|
||||
if client_context.replica_set_name and not directConnection:
|
||||
client_options["replicaSet"] = client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
client_options["directConnection"] = directConnection
|
||||
client_options.update(kwargs)
|
||||
|
||||
uri = _connection_string(host)
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
and "username" not in client_options
|
||||
and "password" not in client_options
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
client = MongoClient(uri, port, **client_options)
|
||||
if client._options.connect:
|
||||
client._connect()
|
||||
return client
|
||||
|
||||
def _async_mongo_client(self, host, port, authenticate=True, directConnection=None, **kwargs):
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or client_context.host
|
||||
port = port or client_context.port
|
||||
client_options: dict = client_context.default_client_options.copy()
|
||||
if client_context.replica_set_name and not directConnection:
|
||||
client_options["replicaSet"] = client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
client_options["directConnection"] = directConnection
|
||||
client_options.update(kwargs)
|
||||
|
||||
uri = _connection_string(host)
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
and "username" not in client_options
|
||||
and "password" not in client_options
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
client = MongoClient(uri, port, **client_options)
|
||||
if client._options.connect:
|
||||
client._connect()
|
||||
self.addCleanup(client.close)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
def unmanaged_single_client_noauth(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return cls._unmanaged_async_mongo_client(
|
||||
h, p, authenticate=False, directConnection=True, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unmanaged_single_client(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Connect to the replica set and authenticate if necessary."""
|
||||
return cls._unmanaged_async_mongo_client(h, p, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def unmanaged_rs_client_noauth(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def unmanaged_rs_or_single_client_noauth(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def unmanaged_rs_or_single_client(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return cls._unmanaged_async_mongo_client(h, p, **kwargs)
|
||||
|
||||
def single_client_noauth(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return self._async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
|
||||
|
||||
def single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Make a direct connection, and authenticate if necessary."""
|
||||
return self._async_mongo_client(h, p, directConnection=True, **kwargs)
|
||||
|
||||
def rs_client_noauth(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Connect to the replica set. Don't authenticate."""
|
||||
return self._async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
def rs_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Connect to the replica set and authenticate if necessary."""
|
||||
return self._async_mongo_client(h, p, **kwargs)
|
||||
|
||||
def rs_or_single_client_noauth(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> MongoClient[dict]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Like rs_or_single_client, but does not authenticate.
|
||||
"""
|
||||
return self._async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Authenticates if necessary.
|
||||
"""
|
||||
return self._async_mongo_client(h, p, **kwargs)
|
||||
|
||||
def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient:
|
||||
if not h and not p:
|
||||
client = MongoClient(**kwargs)
|
||||
else:
|
||||
client = MongoClient(h, p, **kwargs)
|
||||
self.addCleanup(client.close)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
def unmanaged_simple_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient:
|
||||
if not h and not p:
|
||||
client = MongoClient(**kwargs)
|
||||
else:
|
||||
client = MongoClient(h, p, **kwargs)
|
||||
return client
|
||||
|
||||
def disable_replication(self, client):
|
||||
"""Disable replication on all secondaries."""
|
||||
for h, p in client.secondaries:
|
||||
secondary = self.single_client(h, p)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
|
||||
|
||||
def enable_replication(self, client):
|
||||
"""Enable replication on all secondaries."""
|
||||
for h, p in client.secondaries:
|
||||
secondary = self.single_client(h, p)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
|
||||
|
||||
|
||||
class UnitTest(PyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@ -16,8 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import gc
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -27,7 +25,6 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
@ -54,6 +51,8 @@ from test.helpers import (
|
||||
sanitize_reply,
|
||||
)
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -80,6 +79,12 @@ from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _connection_string(h):
|
||||
if h.startswith(("mongodb://", "mongodb+srv://")):
|
||||
return h
|
||||
return f"mongodb://{h!s}"
|
||||
|
||||
|
||||
class AsyncClientContext:
|
||||
client: AsyncMongoClient
|
||||
|
||||
@ -230,6 +235,9 @@ class AsyncClientContext:
|
||||
if not await self._check_user_provided():
|
||||
await _create_user(self.client.admin, db_user, db_pwd)
|
||||
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
self.client = await self._connect(
|
||||
host,
|
||||
port,
|
||||
@ -256,6 +264,8 @@ class AsyncClientContext:
|
||||
if "setName" in hello:
|
||||
self.replica_set_name = str(hello["setName"])
|
||||
self.is_rs = True
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
if self.auth_enabled:
|
||||
# It doesn't matter which member we use as the seed here.
|
||||
self.client = pymongo.AsyncMongoClient(
|
||||
@ -320,6 +330,7 @@ class AsyncClientContext:
|
||||
hello = await mongos_client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if hello.get("msg") == "isdbgrid":
|
||||
self.mongoses.append(next_address)
|
||||
await mongos_client.close()
|
||||
|
||||
async def init(self):
|
||||
with self.conn_lock:
|
||||
@ -539,12 +550,6 @@ class AsyncClientContext:
|
||||
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
|
||||
)
|
||||
|
||||
def require_no_fips(self, func):
|
||||
"""Run a test only if the host does not have FIPS enabled."""
|
||||
return self._require(
|
||||
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
|
||||
)
|
||||
|
||||
def require_no_auth(self, func):
|
||||
"""Run a test only if the server is running without auth enabled."""
|
||||
return self._require(
|
||||
@ -932,6 +937,188 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
|
||||
self.assertEqual(proc.exitcode, 0)
|
||||
|
||||
@classmethod
|
||||
async def _unmanaged_async_mongo_client(
|
||||
cls, host, port, authenticate=True, directConnection=None, **kwargs
|
||||
):
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or await async_client_context.host
|
||||
port = port or await async_client_context.port
|
||||
client_options: dict = async_client_context.default_client_options.copy()
|
||||
if async_client_context.replica_set_name and not directConnection:
|
||||
client_options["replicaSet"] = async_client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
client_options["directConnection"] = directConnection
|
||||
client_options.update(kwargs)
|
||||
|
||||
uri = _connection_string(host)
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
and "username" not in client_options
|
||||
and "password" not in client_options
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
client = AsyncMongoClient(uri, port, **client_options)
|
||||
if client._options.connect:
|
||||
await client.aconnect()
|
||||
return client
|
||||
|
||||
async def _async_mongo_client(
|
||||
self, host, port, authenticate=True, directConnection=None, **kwargs
|
||||
):
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or await async_client_context.host
|
||||
port = port or await async_client_context.port
|
||||
client_options: dict = async_client_context.default_client_options.copy()
|
||||
if async_client_context.replica_set_name and not directConnection:
|
||||
client_options["replicaSet"] = async_client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
client_options["directConnection"] = directConnection
|
||||
client_options.update(kwargs)
|
||||
|
||||
uri = _connection_string(host)
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
and "username" not in client_options
|
||||
and "password" not in client_options
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
client = AsyncMongoClient(uri, port, **client_options)
|
||||
if client._options.connect:
|
||||
await client.aconnect()
|
||||
self.addAsyncCleanup(client.close)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def unmanaged_async_single_client_noauth(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await cls._unmanaged_async_mongo_client(
|
||||
h, p, authenticate=False, directConnection=True, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def unmanaged_async_single_client(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def unmanaged_async_rs_client(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set and authenticate if necessary."""
|
||||
return await cls._unmanaged_async_mongo_client(h, p, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def unmanaged_async_rs_client_noauth(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def unmanaged_async_rs_or_single_client_noauth(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def unmanaged_async_rs_or_single_client(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await cls._unmanaged_async_mongo_client(h, p, **kwargs)
|
||||
|
||||
async def async_single_client_noauth(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await self._async_mongo_client(
|
||||
h, p, authenticate=False, directConnection=True, **kwargs
|
||||
)
|
||||
|
||||
async def async_single_client(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection, and authenticate if necessary."""
|
||||
return await self._async_mongo_client(h, p, directConnection=True, **kwargs)
|
||||
|
||||
async def async_rs_client_noauth(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set. Don't authenticate."""
|
||||
return await self._async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
async def async_rs_client(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set and authenticate if necessary."""
|
||||
return await self._async_mongo_client(h, p, **kwargs)
|
||||
|
||||
async def async_rs_or_single_client_noauth(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Like rs_or_single_client, but does not authenticate.
|
||||
"""
|
||||
return await self._async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
async def async_rs_or_single_client(
|
||||
self, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[Any]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Authenticates if necessary.
|
||||
"""
|
||||
return await self._async_mongo_client(h, p, **kwargs)
|
||||
|
||||
def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient:
|
||||
if not h and not p:
|
||||
client = AsyncMongoClient(**kwargs)
|
||||
else:
|
||||
client = AsyncMongoClient(h, p, **kwargs)
|
||||
self.addAsyncCleanup(client.close)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
def unmanaged_simple_client(
|
||||
cls, h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient:
|
||||
if not h and not p:
|
||||
client = AsyncMongoClient(**kwargs)
|
||||
else:
|
||||
client = AsyncMongoClient(h, p, **kwargs)
|
||||
return client
|
||||
|
||||
async def disable_replication(self, client):
|
||||
"""Disable replication on all secondaries."""
|
||||
for h, p in client.secondaries:
|
||||
secondary = await self.async_single_client(h, p)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
|
||||
|
||||
async def enable_replication(self, client):
|
||||
"""Enable replication on all secondaries."""
|
||||
for h, p in client.secondaries:
|
||||
secondary = await self.async_single_client(h, p)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
|
||||
|
||||
|
||||
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@ -23,16 +23,14 @@ from urllib.parse import quote_plus
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, SkipTest, async_client_context, unittest
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
async_rs_or_single_client,
|
||||
async_rs_or_single_client_noauth,
|
||||
async_single_client,
|
||||
async_single_client_noauth,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
AsyncPyMongoTestCase,
|
||||
SkipTest,
|
||||
async_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import AllowListEventListener, delay, ignore_deprecations
|
||||
|
||||
from pymongo import AsyncMongoClient, monitoring
|
||||
from pymongo.asynchronous.auth import HAVE_KERBEROS
|
||||
@ -81,7 +79,7 @@ class AutoAuthenticateThread(threading.Thread):
|
||||
self.success = True
|
||||
|
||||
|
||||
class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
class TestGSSAPI(AsyncPyMongoTestCase):
|
||||
mech_properties: str
|
||||
service_realm_required: bool
|
||||
|
||||
@ -138,7 +136,7 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
if not self.service_realm_required:
|
||||
# Without authMechanismProperties.
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -149,11 +147,11 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
await client[GSSAPI_DB].collection.find_one()
|
||||
|
||||
# Log in using URI, without authMechanismProperties.
|
||||
client = AsyncMongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
await client[GSSAPI_DB].collection.find_one()
|
||||
|
||||
# Authenticate with authMechanismProperties.
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -166,14 +164,14 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# Log in using URI, with authMechanismProperties.
|
||||
mech_uri = uri + f"&authMechanismProperties={self.mech_properties}"
|
||||
client = AsyncMongoClient(mech_uri)
|
||||
client = self.simple_client(mech_uri)
|
||||
await client[GSSAPI_DB].collection.find_one()
|
||||
|
||||
set_name = async_client_context.replica_set_name
|
||||
if set_name:
|
||||
if not self.service_realm_required:
|
||||
# Without authMechanismProperties
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -185,11 +183,11 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
await client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
uri = uri + f"&replicaSet={set_name!s}"
|
||||
client = AsyncMongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
await client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
# With authMechanismProperties
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -202,13 +200,13 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
await client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
mech_uri = mech_uri + f"&replicaSet={set_name!s}"
|
||||
client = AsyncMongoClient(mech_uri)
|
||||
client = self.simple_client(mech_uri)
|
||||
await client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
@ignore_deprecations
|
||||
@async_client_context.require_sync
|
||||
async def test_gssapi_threaded(self):
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -244,7 +242,7 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
set_name = async_client_context.replica_set_name
|
||||
if set_name:
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -267,14 +265,14 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(thread.success)
|
||||
|
||||
|
||||
class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
|
||||
class TestSASLPlain(AsyncPyMongoTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not SASL_HOST or not SASL_USER or not SASL_PASS:
|
||||
raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL")
|
||||
|
||||
async def test_sasl_plain(self):
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
SASL_HOST,
|
||||
SASL_PORT,
|
||||
username=SASL_USER,
|
||||
@ -293,12 +291,12 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
|
||||
SASL_PORT,
|
||||
SASL_DB,
|
||||
)
|
||||
client = AsyncMongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
await client.ldap.test.find_one()
|
||||
|
||||
set_name = async_client_context.replica_set_name
|
||||
if set_name:
|
||||
client = AsyncMongoClient(
|
||||
client = self.simple_client(
|
||||
SASL_HOST,
|
||||
SASL_PORT,
|
||||
replicaSet=set_name,
|
||||
@ -317,7 +315,7 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
|
||||
SASL_DB,
|
||||
str(set_name),
|
||||
)
|
||||
client = AsyncMongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
await client.ldap.test.find_one()
|
||||
|
||||
async def test_sasl_plain_bad_credentials(self):
|
||||
@ -331,8 +329,8 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
return uri
|
||||
|
||||
bad_user = AsyncMongoClient(auth_string("not-user", SASL_PASS))
|
||||
bad_pwd = AsyncMongoClient(auth_string(SASL_USER, "not-pwd"))
|
||||
bad_user = self.simple_client(auth_string("not-user", SASL_PASS))
|
||||
bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd"))
|
||||
# OperationFailure raised upon connecting.
|
||||
with self.assertRaises(OperationFailure):
|
||||
await bad_user.admin.command("ping")
|
||||
@ -356,7 +354,7 @@ class TestSCRAMSHA1(AsyncIntegrationTest):
|
||||
async def test_scram_sha1(self):
|
||||
host, port = await async_client_context.host, await async_client_context.port
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
"mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port)
|
||||
)
|
||||
await client.pymongo_test.command("dbstats")
|
||||
@ -367,7 +365,7 @@ class TestSCRAMSHA1(AsyncIntegrationTest):
|
||||
"@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1"
|
||||
"&replicaSet=%s" % (host, port, async_client_context.replica_set_name)
|
||||
)
|
||||
client = await async_single_client_noauth(uri)
|
||||
client = await self.async_single_client_noauth(uri)
|
||||
await client.pymongo_test.command("dbstats")
|
||||
db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY)
|
||||
await db.command("dbstats")
|
||||
@ -395,7 +393,7 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
"testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
|
||||
)
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram", event_listeners=[listener]
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
@ -432,38 +430,38 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
)
|
||||
|
||||
# Step 2: verify auth success cases
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha1", password="pwd", authSource="testscram"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
# Step 2: SCRAM-SHA-1 and SCRAM-SHA-256
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
self.listener.reset()
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="both", password="pwd", authSource="testscram", event_listeners=[self.listener]
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
@ -476,19 +474,19 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256")
|
||||
|
||||
# Step 3: verify auth failure conditions
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
|
||||
)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="not-a-user", password="pwd", authSource="testscram"
|
||||
)
|
||||
with self.assertRaises(OperationFailure):
|
||||
@ -501,7 +499,7 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
port,
|
||||
async_client_context.replica_set_name,
|
||||
)
|
||||
client = await async_single_client_noauth(uri)
|
||||
client = await self.async_single_client_noauth(uri)
|
||||
await client.testscram.command("dbstats")
|
||||
db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY)
|
||||
await db.command("dbstats")
|
||||
@ -521,12 +519,12 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
"testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
|
||||
)
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="\u2168", password="\u2163", authSource="testscram"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="\u2168",
|
||||
password="\u2163",
|
||||
authSource="testscram",
|
||||
@ -534,17 +532,17 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="\u2168", password="IV", authSource="testscram"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="IX", password="I\u00ADX", authSource="testscram"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="IX",
|
||||
password="I\u00ADX",
|
||||
authSource="testscram",
|
||||
@ -552,31 +550,31 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
"mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
"mongodb://\u2168:IV@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
"mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
client = await async_rs_or_single_client_noauth(
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
"mongodb://IX:IX@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
await client.testscram.command("dbstats")
|
||||
|
||||
async def test_cache(self):
|
||||
client = await async_single_client()
|
||||
client = await self.async_single_client()
|
||||
credentials = client.options.pool_options._credentials
|
||||
cache = credentials.cache
|
||||
self.assertIsNotNone(cache)
|
||||
@ -601,8 +599,7 @@ class TestSCRAM(AsyncIntegrationTest):
|
||||
await coll.insert_one({"_id": 1})
|
||||
|
||||
# The first thread to call find() will authenticate
|
||||
client = await async_rs_or_single_client()
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client()
|
||||
coll = client.db.test
|
||||
threads = []
|
||||
for _ in range(4):
|
||||
@ -631,7 +628,9 @@ class TestAuthURIOptions(AsyncIntegrationTest):
|
||||
async def test_uri_options(self):
|
||||
# Test default to admin
|
||||
host, port = await async_client_context.host, await async_client_context.port
|
||||
client = await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))
|
||||
client = await self.async_rs_or_single_client_noauth(
|
||||
"mongodb://admin:pass@%s:%d" % (host, port)
|
||||
)
|
||||
self.assertTrue(await client.admin.command("dbstats"))
|
||||
|
||||
if async_client_context.is_rs:
|
||||
@ -640,14 +639,14 @@ class TestAuthURIOptions(AsyncIntegrationTest):
|
||||
port,
|
||||
async_client_context.replica_set_name,
|
||||
)
|
||||
client = await async_single_client_noauth(uri)
|
||||
client = await self.async_single_client_noauth(uri)
|
||||
self.assertTrue(await client.admin.command("dbstats"))
|
||||
db = client.get_database("admin", read_preference=ReadPreference.SECONDARY)
|
||||
self.assertTrue(await db.command("dbstats"))
|
||||
|
||||
# Test explicit database
|
||||
uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port)
|
||||
client = await async_rs_or_single_client_noauth(uri)
|
||||
client = await self.async_rs_or_single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await client.admin.command("dbstats")
|
||||
self.assertTrue(await client.pymongo_test.command("dbstats"))
|
||||
@ -658,7 +657,7 @@ class TestAuthURIOptions(AsyncIntegrationTest):
|
||||
port,
|
||||
async_client_context.replica_set_name,
|
||||
)
|
||||
client = await async_single_client_noauth(uri)
|
||||
client = await self.async_single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await client.admin.command("dbstats")
|
||||
self.assertTrue(await client.pymongo_test.command("dbstats"))
|
||||
@ -667,7 +666,7 @@ class TestAuthURIOptions(AsyncIntegrationTest):
|
||||
|
||||
# Test authSource
|
||||
uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port)
|
||||
client = await async_rs_or_single_client_noauth(uri)
|
||||
client = await self.async_rs_or_single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await client.pymongo_test2.command("dbstats")
|
||||
self.assertTrue(await client.pymongo_test.command("dbstats"))
|
||||
@ -677,7 +676,7 @@ class TestAuthURIOptions(AsyncIntegrationTest):
|
||||
"mongodb://user:pass@%s:%d/pymongo_test2?replicaSet="
|
||||
"%s;authSource=pymongo_test" % (host, port, async_client_context.replica_set_name)
|
||||
)
|
||||
client = await async_single_client_noauth(uri)
|
||||
client = await self.async_single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await client.pymongo_test2.command("dbstats")
|
||||
self.assertTrue(await client.pymongo_test.command("dbstats"))
|
||||
|
||||
@ -20,6 +20,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from test.asynchronous import AsyncPyMongoTestCase
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -34,7 +35,7 @@ _IS_SYNC = False
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")
|
||||
|
||||
|
||||
class TestAuthSpec(unittest.IsolatedAsyncioTestCase):
|
||||
class TestAuthSpec(AsyncPyMongoTestCase):
|
||||
pass
|
||||
|
||||
|
||||
@ -54,7 +55,7 @@ def create_test(test_case):
|
||||
warnings.simplefilter("default")
|
||||
self.assertRaises(Exception, AsyncMongoClient, uri, connect=False)
|
||||
else:
|
||||
client = AsyncMongoClient(uri, connect=False)
|
||||
client = self.simple_client(uri, connect=False)
|
||||
credentials = client.options.pool_options._credentials
|
||||
if credential is None:
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
@ -24,23 +24,14 @@ from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest
|
||||
from test.utils import (
|
||||
async_rs_or_single_client_noauth,
|
||||
async_single_client,
|
||||
async_wait_until,
|
||||
)
|
||||
from test.utils import async_wait_until
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure
|
||||
from pymongo.operations import *
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
@ -915,7 +906,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase):
|
||||
async def test_readonly(self):
|
||||
# We test that an authorization failure aborts the batch and is raised
|
||||
# as OperationFailure.
|
||||
cli = await async_rs_or_single_client_noauth(
|
||||
cli = await self.async_rs_or_single_client_noauth(
|
||||
username="readonly", password="pw", authSource="pymongo_test"
|
||||
)
|
||||
coll = cli.pymongo_test.test
|
||||
@ -926,7 +917,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase):
|
||||
async def test_no_remove(self):
|
||||
# We test that an authorization failure aborts the batch and is raised
|
||||
# as OperationFailure.
|
||||
cli = await async_rs_or_single_client_noauth(
|
||||
cli = await self.async_rs_or_single_client_noauth(
|
||||
username="noremove", password="pw", authSource="pymongo_test"
|
||||
)
|
||||
coll = cli.pymongo_test.test
|
||||
@ -954,7 +945,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
||||
if cls.w is not None and cls.w > 1:
|
||||
for member in (await async_client_context.hello)["hosts"]:
|
||||
if member != (await async_client_context.hello)["primary"]:
|
||||
cls.secondary = await async_single_client(*partition_node(member))
|
||||
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -28,12 +28,17 @@ from typing import no_type_check
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, Version, async_client_context, unittest
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
AsyncPyMongoTestCase,
|
||||
Version,
|
||||
async_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
EventListener,
|
||||
async_rs_or_single_client,
|
||||
async_wait_until,
|
||||
)
|
||||
|
||||
@ -69,8 +74,7 @@ class TestAsyncChangeStreamBase(AsyncIntegrationTest):
|
||||
async def client_with_listener(self, *commands):
|
||||
"""Return a client with a AllowListEventListener."""
|
||||
listener = AllowListEventListener(*commands)
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
return client, listener
|
||||
|
||||
def watched_collection(self, *args, **kwargs):
|
||||
@ -176,7 +180,7 @@ class APITestsMixin:
|
||||
@no_type_check
|
||||
async def test_try_next_runs_one_getmore(self):
|
||||
listener = EventListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
# Connect to the cluster.
|
||||
await client.admin.command("ping")
|
||||
listener.reset()
|
||||
@ -234,7 +238,7 @@ class APITestsMixin:
|
||||
@no_type_check
|
||||
async def test_batch_size_is_honored(self):
|
||||
listener = EventListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
# Connect to the cluster.
|
||||
await client.admin.command("ping")
|
||||
listener.reset()
|
||||
@ -481,7 +485,9 @@ class ProseSpecTestsMixin:
|
||||
@no_type_check
|
||||
async def _client_with_listener(self, *commands):
|
||||
listener = AllowListEventListener(*commands)
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client(
|
||||
event_listeners=[listener]
|
||||
)
|
||||
self.addAsyncCleanup(client.close)
|
||||
return client, listener
|
||||
|
||||
@ -1131,7 +1137,7 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
||||
cls.client = await async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -27,7 +27,6 @@ from test.asynchronous import (
|
||||
)
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
async_rs_or_single_client,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -39,7 +38,6 @@ from pymongo.errors import (
|
||||
InvalidOperation,
|
||||
NetworkTimeout,
|
||||
)
|
||||
from pymongo.monitoring import *
|
||||
from pymongo.operations import *
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
@ -97,8 +95,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_batch_splits_if_num_operations_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
models = []
|
||||
for _ in range(self.max_write_batch_size + 1):
|
||||
@ -123,8 +120,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_batch_splits_if_ops_payload_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
models = []
|
||||
num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1)
|
||||
@ -157,11 +153,10 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_collects_write_concern_errors_across_batches(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(
|
||||
client = await self.async_rs_or_single_client(
|
||||
event_listeners=[listener],
|
||||
retryWrites=False,
|
||||
)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
@ -200,8 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_collects_write_errors_across_batches_unordered(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -231,8 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_collects_write_errors_across_batches_ordered(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -262,8 +255,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_handles_cursor_requiring_getMore(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -304,8 +296,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_handles_cursor_requiring_getMore_within_transaction(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -348,8 +339,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_handles_getMore_error(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -403,8 +393,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_returns_error_if_unacknowledged_too_large_insert(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
b_repeated = "b" * self.max_bson_object_size
|
||||
|
||||
@ -460,8 +449,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_no_batch_splits_if_new_namespace_is_not_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
num_models, models = await self._setup_namespace_test_models()
|
||||
models.append(
|
||||
@ -492,8 +480,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_batch_splits_if_new_namespace_is_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
num_models, models = await self._setup_namespace_test_models()
|
||||
c_repeated = "c" * 200
|
||||
@ -530,8 +517,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_version_min(8, 0, 0, -24)
|
||||
@async_client_context.require_no_serverless
|
||||
async def test_returns_error_if_no_writes_can_be_added_to_ops(self):
|
||||
client = await async_rs_or_single_client()
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client()
|
||||
|
||||
# Document too large.
|
||||
b_repeated = "b" * self.max_message_size_bytes
|
||||
@ -554,8 +540,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
key_vault_namespace="db.coll",
|
||||
kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}},
|
||||
)
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
|
||||
with self.assertRaises(InvalidOperation) as context:
|
||||
@ -580,7 +565,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
async def test_timeout_in_multi_batch_bulk_write(self):
|
||||
_OVERHEAD = 500
|
||||
|
||||
internal_client = await async_rs_or_single_client(timeoutMS=None)
|
||||
internal_client = await self.async_rs_or_single_client(timeoutMS=None)
|
||||
self.addAsyncCleanup(internal_client.close)
|
||||
|
||||
collection = internal_client.db["coll"]
|
||||
@ -605,14 +590,13 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(
|
||||
client = await self.async_rs_or_single_client(
|
||||
event_listeners=[listener],
|
||||
readConcernLevel="majority",
|
||||
readPreference="primary",
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.admin.command("ping") # Init the client first.
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
await client.bulk_write(models=models)
|
||||
|
||||
@ -30,6 +30,7 @@ sys.path[0:0] = [""]
|
||||
from test import unittest
|
||||
from test.asynchronous import ( # TODO: fix sync imports in PYTHON-4528
|
||||
AsyncIntegrationTest,
|
||||
AsyncUnitTest,
|
||||
async_client_context,
|
||||
)
|
||||
from test.utils import (
|
||||
@ -37,8 +38,6 @@ from test.utils import (
|
||||
EventListener,
|
||||
async_get_pool,
|
||||
async_is_mongos,
|
||||
async_rs_or_single_client,
|
||||
async_single_client,
|
||||
async_wait_until,
|
||||
wait_until,
|
||||
)
|
||||
@ -82,14 +81,20 @@ from pymongo.write_concern import WriteConcern
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestCollectionNoConnect(unittest.TestCase):
|
||||
class TestCollectionNoConnect(AsyncUnitTest):
|
||||
"""Test Collection features on a client that does not connect."""
|
||||
|
||||
db: AsyncDatabase
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.db = AsyncMongoClient(connect=False).pymongo_test
|
||||
async def _setup_class(cls):
|
||||
cls.client = AsyncMongoClient(connect=False)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
|
||||
@ -1819,8 +1824,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
# Insert enough documents to require more than one batch
|
||||
await self.db.test.insert_many([{"i": i} for i in range(150)])
|
||||
|
||||
client = await async_rs_or_single_client(maxPoolSize=1)
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(maxPoolSize=1)
|
||||
pool = await async_get_pool(client)
|
||||
|
||||
# Make sure the socket is returned after exhaustion.
|
||||
@ -2100,7 +2104,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
|
||||
async def test_find_one_and_write_concern(self):
|
||||
listener = EventListener()
|
||||
db = (await async_single_client(event_listeners=[listener]))[self.db.name]
|
||||
db = (await self.async_single_client(event_listeners=[listener]))[self.db.name]
|
||||
# non-default WriteConcern.
|
||||
c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0))
|
||||
# default WriteConcern.
|
||||
|
||||
@ -34,7 +34,7 @@ from test.utils import (
|
||||
AllowListEventListener,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
async_rs_or_single_client,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
wait_until,
|
||||
)
|
||||
@ -45,7 +45,7 @@ from pymongo import ASCENDING, DESCENDING
|
||||
from pymongo.asynchronous.cursor import AsyncCursor, CursorType
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
|
||||
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure, PyMongoError
|
||||
from pymongo.operations import _IndexList
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
@ -232,7 +232,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
self.assertEqual(90, cursor._max_await_time_ms)
|
||||
|
||||
listener = AllowListEventListener("find", "getMore")
|
||||
coll = (await async_rs_or_single_client(event_listeners=[listener]))[
|
||||
coll = (await self.async_rs_or_single_client(event_listeners=[listener]))[
|
||||
self.db.name
|
||||
].pymongo_test
|
||||
|
||||
@ -353,8 +353,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_explain_with_read_concern(self):
|
||||
# Do not add readConcern level to explain.
|
||||
listener = AllowListEventListener("explain")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local"))
|
||||
self.assertTrue(await coll.find().explain())
|
||||
started = listener.started_events
|
||||
@ -1261,8 +1260,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
await self.client._process_periodic_tasks()
|
||||
|
||||
listener = AllowListEventListener("killCursors")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test_close_kills_cursors
|
||||
|
||||
# Add some test data.
|
||||
@ -1300,8 +1298,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_timeout_kills_cursor_asynchronously(self):
|
||||
listener = AllowListEventListener("killCursors")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test_timeout_kills_cursor
|
||||
|
||||
# Add some test data.
|
||||
@ -1358,8 +1355,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
|
||||
async def test_getMore_does_not_send_readPreference(self):
|
||||
listener = AllowListEventListener("find", "getMore")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
# We never send primary read preference so override the default.
|
||||
coll = client[self.db.name].get_collection(
|
||||
"test", read_preference=ReadPreference.PRIMARY_PREFERRED
|
||||
@ -1415,6 +1411,18 @@ class TestCursor(AsyncIntegrationTest):
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 2)
|
||||
|
||||
async def test_to_list_csot_applied(self):
|
||||
client = await self.async_single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
await client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
cursor = coll.find({"$where": delay(1)})
|
||||
with self.assertRaises(PyMongoError) as ctx:
|
||||
await cursor.to_list()
|
||||
self.assertTrue(ctx.exception.timeout)
|
||||
|
||||
@async_client_context.require_change_streams
|
||||
async def test_command_cursor_to_list(self):
|
||||
# Set maxAwaitTimeMS=1 to speed up the test.
|
||||
@ -1444,6 +1452,25 @@ class TestCursor(AsyncIntegrationTest):
|
||||
result = await db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(await result.to_list(1)), 1)
|
||||
|
||||
@async_client_context.require_failCommand_blockConnection
|
||||
async def test_command_cursor_to_list_csot_applied(self):
|
||||
client = await self.async_single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
await client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 5},
|
||||
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 1000},
|
||||
}
|
||||
cursor = await coll.aggregate([], batchSize=1)
|
||||
async with self.fail_point(fail_command):
|
||||
with self.assertRaises(PyMongoError) as ctx:
|
||||
await cursor.to_list()
|
||||
self.assertTrue(ctx.exception.timeout)
|
||||
|
||||
|
||||
class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
async def test_find_raw(self):
|
||||
@ -1463,7 +1490,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
await c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
async with client.start_session() as session:
|
||||
async with await session.start_transaction():
|
||||
batches = await (
|
||||
@ -1493,7 +1520,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
await c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
async with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
|
||||
):
|
||||
@ -1514,7 +1541,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
await c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
db = client[self.db.name]
|
||||
async with client.start_session(snapshot=True) as session:
|
||||
await db.test.distinct("x", {}, session=session)
|
||||
@ -1577,7 +1604,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
|
||||
async def test_monitoring(self):
|
||||
listener = EventListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
c = client.pymongo_test.test
|
||||
await c.drop()
|
||||
await c.insert_many([{"_id": i} for i in range(10)])
|
||||
@ -1643,7 +1670,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
await c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
async with client.start_session() as session:
|
||||
async with await session.start_transaction():
|
||||
batches = await (
|
||||
@ -1674,7 +1701,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
await c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
async with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}}
|
||||
):
|
||||
@ -1698,7 +1725,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
await c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
db = client[self.db.name]
|
||||
async with client.start_session(snapshot=True) as session:
|
||||
await db.test.distinct("x", {}, session=session)
|
||||
@ -1744,7 +1771,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
|
||||
async def test_monitoring(self):
|
||||
listener = EventListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
c = client.pymongo_test.test
|
||||
await c.drop()
|
||||
await c.insert_many([{"_id": i} for i in range(10)])
|
||||
@ -1788,8 +1815,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_mongos
|
||||
async def test_exhaust_cursor_db_set(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
c = client.pymongo_test.test
|
||||
await c.delete_many({})
|
||||
await c.insert_many([{"_id": i} for i in range(3)])
|
||||
|
||||
@ -29,7 +29,6 @@ from test.test_custom_types import DECIMAL_CODECOPTS
|
||||
from test.utils import (
|
||||
IMPOSSIBLE_WRITE_CONCERN,
|
||||
OvertCommandListener,
|
||||
async_rs_or_single_client,
|
||||
async_wait_until,
|
||||
)
|
||||
|
||||
@ -208,7 +207,7 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
|
||||
async def test_list_collection_names_filter(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
db = client[self.db.name]
|
||||
await db.capped.drop()
|
||||
await db.create_collection("capped", capped=True, size=4096)
|
||||
@ -235,8 +234,7 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
|
||||
async def test_check_exists(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
db = client[self.db.name]
|
||||
await db.drop_collection("unique")
|
||||
await db.create_collection("unique", check_exists=True)
|
||||
@ -326,7 +324,7 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
await self.client.drop_database("pymongo_test")
|
||||
|
||||
async def test_list_collection_names_single_socket(self):
|
||||
client = await async_rs_or_single_client(maxPoolSize=1)
|
||||
client = await self.async_rs_or_single_client(maxPoolSize=1)
|
||||
await client.drop_database("test_collection_names_single_socket")
|
||||
db = client.test_collection_names_single_socket
|
||||
for i in range(200):
|
||||
|
||||
@ -31,7 +31,7 @@ import warnings
|
||||
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
|
||||
from test.asynchronous.test_bulk import AsyncBulkTestBase
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@ -44,6 +44,8 @@ sys.path[0:0] = [""]
|
||||
from test import (
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.test_bulk import AsyncBulkTestBase
|
||||
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
|
||||
from test.helpers import (
|
||||
AWS_CREDS,
|
||||
AZURE_CREDS,
|
||||
@ -59,12 +61,10 @@ from test.utils import (
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
TopologyEventListener,
|
||||
async_rs_or_single_client,
|
||||
async_wait_until,
|
||||
camel_to_snake_args,
|
||||
is_greenthread_patched,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunner
|
||||
|
||||
from bson import DatetimeMS, Decimal128, encode, json_util
|
||||
from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation
|
||||
@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
|
||||
@unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed")
|
||||
async def test_crypt_shared(self):
|
||||
# Test that we can pick up crypt_shared lib automatically
|
||||
client = AsyncMongoClient(
|
||||
self.simple_client(
|
||||
auto_encryption_opts=AutoEncryptionOpts(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True
|
||||
),
|
||||
connect=False,
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
|
||||
@unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed")
|
||||
def test_init_requires_pymongocrypt(self):
|
||||
@ -196,19 +195,16 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
|
||||
|
||||
class TestClientOptions(AsyncPyMongoTestCase):
|
||||
async def test_default(self):
|
||||
client = AsyncMongoClient(connect=False)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = self.simple_client(connect=False)
|
||||
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
|
||||
|
||||
client = AsyncMongoClient(auto_encryption_opts=None, connect=False)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = self.simple_client(auto_encryption_opts=None, connect=False)
|
||||
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
async def test_kwargs(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = AsyncMongoClient(auto_encryption_opts=opts, connect=False)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = self.simple_client(auto_encryption_opts=opts, connect=False)
|
||||
self.assertEqual(get_client_opts(client).auto_encryption_opts, opts)
|
||||
|
||||
|
||||
@ -229,6 +225,34 @@ class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
|
||||
self.assertIsInstance(val, Binary)
|
||||
self.assertEqual(val.subtype, UUID_SUBTYPE)
|
||||
|
||||
def create_client_encryption(
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: AsyncMongoClient,
|
||||
codec_options: CodecOptions,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
return client_encryption
|
||||
|
||||
@classmethod
|
||||
def unmanaged_create_client_encryption(
|
||||
cls,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: AsyncMongoClient,
|
||||
codec_options: CodecOptions,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
|
||||
)
|
||||
return client_encryption
|
||||
|
||||
|
||||
# Location of JSON test files.
|
||||
if _IS_SYNC:
|
||||
@ -260,8 +284,7 @@ def bson_data(*paths):
|
||||
|
||||
class TestClientSimple(AsyncEncryptionIntegrationTest):
|
||||
async def _test_auto_encrypt(self, opts):
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
# Create the encrypted field's data key.
|
||||
key_vault = await create_key_vault(
|
||||
@ -342,8 +365,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest):
|
||||
|
||||
async def test_use_after_close(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
await client.admin.command("ping")
|
||||
await client.aclose()
|
||||
@ -358,10 +380,10 @@ class TestClientSimple(AsyncEncryptionIntegrationTest):
|
||||
is_greenthread_patched(),
|
||||
"gevent and eventlet do not support POSIX-style forking.",
|
||||
)
|
||||
@async_client_context.require_sync
|
||||
async def test_fork(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
async def target():
|
||||
with warnings.catch_warnings():
|
||||
@ -375,8 +397,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest):
|
||||
class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
|
||||
async def test_upsert_uuid_standard_encrypt(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
encrypted_coll = client.pymongo_test.test
|
||||
@ -416,8 +437,7 @@ class TestClientMaxWireVersion(AsyncIntegrationTest):
|
||||
@async_client_context.require_version_max(4, 0, 99)
|
||||
async def test_raise_max_wire_version_error(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
msg = "Auto-encryption requires a minimum MongoDB version of 4.2"
|
||||
with self.assertRaisesRegex(ConfigurationError, msg):
|
||||
await client.test.test.insert_one({})
|
||||
@ -430,8 +450,7 @@ class TestClientMaxWireVersion(AsyncIntegrationTest):
|
||||
|
||||
async def test_raise_unsupported_error(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
msg = "find_raw_batches does not support auto encryption"
|
||||
with self.assertRaisesRegex(InvalidOperation, msg):
|
||||
await client.test.test.find_raw_batches({})
|
||||
@ -450,10 +469,9 @@ class TestClientMaxWireVersion(AsyncIntegrationTest):
|
||||
|
||||
class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
async def test_encrypt_decrypt(self):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
# Use standard UUID representation.
|
||||
key_vault = async_client_context.client.keyvault.get_collection(
|
||||
"datakeys", codec_options=OPTS
|
||||
@ -495,10 +513,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
self.assertEqual(decrypted_ssn, doc["ssn"])
|
||||
|
||||
async def test_validation(self):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
|
||||
msg = "value to decrypt must be a bson.binary.Binary with subtype 6"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
@ -512,10 +529,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
await client_encryption.encrypt("str", algo, key_id=Binary(b"123"))
|
||||
|
||||
async def test_bson_errors(self):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
|
||||
# Attempt to encrypt an unencodable object.
|
||||
unencodable_value = object()
|
||||
@ -528,7 +544,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
|
||||
async def test_codec_options(self):
|
||||
with self.assertRaisesRegex(TypeError, "codec_options must be"):
|
||||
AsyncClientEncryption(
|
||||
self.create_client_encryption(
|
||||
KMS_PROVIDERS,
|
||||
"keyvault.datakeys",
|
||||
async_client_context.client,
|
||||
@ -536,10 +552,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
)
|
||||
|
||||
opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY)
|
||||
client_encryption_legacy = AsyncClientEncryption(
|
||||
client_encryption_legacy = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption_legacy.close)
|
||||
|
||||
# Create the encrypted field's data key.
|
||||
key_id = await client_encryption_legacy.create_data_key("local")
|
||||
@ -554,10 +569,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
|
||||
# Encrypt the same UUID with STANDARD codec options.
|
||||
opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
encrypted_standard = await client_encryption.encrypt(
|
||||
value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
|
||||
)
|
||||
@ -573,7 +587,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
self.assertNotEqual(await client_encryption.decrypt(encrypted_legacy), value)
|
||||
|
||||
async def test_close(self):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
|
||||
)
|
||||
await client_encryption.close()
|
||||
@ -589,7 +603,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest):
|
||||
await client_encryption.decrypt(Binary(b"", 6))
|
||||
|
||||
async def test_with_statement(self):
|
||||
async with AsyncClientEncryption(
|
||||
async with self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS
|
||||
) as client_encryption:
|
||||
pass
|
||||
@ -613,7 +627,7 @@ KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PE
|
||||
|
||||
if _IS_SYNC:
|
||||
# TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
|
||||
class TestSpec(SpecRunner):
|
||||
class TestSpec(AsyncSpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
@ -811,7 +825,7 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
await cls.client.db.coll.drop()
|
||||
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
|
||||
|
||||
@ -833,10 +847,10 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
||||
opts = AutoEncryptionOpts(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
cls.client_encrypted = await async_rs_or_single_client(
|
||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
cls.client_encryption = AsyncClientEncryption(
|
||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
|
||||
@ -923,10 +937,9 @@ class TestExternalKeyVault(AsyncEncryptionIntegrationTest):
|
||||
# Configure the encrypted field via the local schema_map option.
|
||||
schemas = {"db.coll": json_data("external", "external-schema.json")}
|
||||
if with_external_key_vault:
|
||||
key_vault_client = await async_rs_or_single_client(
|
||||
key_vault_client = await self.async_rs_or_single_client(
|
||||
username="fake-user", password="fake-pwd"
|
||||
)
|
||||
self.addAsyncCleanup(key_vault_client.close)
|
||||
else:
|
||||
key_vault_client = async_client_context.client
|
||||
opts = AutoEncryptionOpts(
|
||||
@ -936,15 +949,13 @@ class TestExternalKeyVault(AsyncEncryptionIntegrationTest):
|
||||
key_vault_client=key_vault_client,
|
||||
)
|
||||
|
||||
client_encrypted = await async_rs_or_single_client(
|
||||
client_encrypted = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
self.addAsyncCleanup(client_encrypted.close)
|
||||
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
|
||||
if with_external_key_vault:
|
||||
# Authentication error.
|
||||
@ -990,10 +1001,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
|
||||
self.addAsyncCleanup(self.client.db.view.drop)
|
||||
|
||||
opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys")
|
||||
client_encrypted = await async_rs_or_single_client(
|
||||
client_encrypted = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
self.addAsyncCleanup(client_encrypted.aclose)
|
||||
|
||||
with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"):
|
||||
await client_encrypted.db.view.insert_one({})
|
||||
@ -1050,17 +1060,15 @@ class TestCorpus(AsyncEncryptionIntegrationTest):
|
||||
)
|
||||
self.addAsyncCleanup(vault.drop)
|
||||
|
||||
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client_encrypted.close)
|
||||
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.kms_providers(),
|
||||
"keyvault.datakeys",
|
||||
async_client_context.client,
|
||||
OPTS,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
|
||||
corpus = self.fix_up_curpus(json_data("corpus", "corpus.json"))
|
||||
corpus_copied: SON = SON()
|
||||
@ -1203,7 +1211,7 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
|
||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client_encrypted = await async_rs_or_single_client(
|
||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
||||
)
|
||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
||||
@ -1291,7 +1299,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
"gcp": GCP_CREDS,
|
||||
"kmip": KMIP_CREDS,
|
||||
}
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers=kms_providers,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
@ -1303,7 +1311,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443"
|
||||
kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443"
|
||||
kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698"
|
||||
self.client_encryption_invalid = AsyncClientEncryption(
|
||||
self.client_encryption_invalid = self.create_client_encryption(
|
||||
kms_providers=kms_providers_invalid,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
@ -1484,7 +1492,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
await self.client_encryption.create_data_key("kmip", key)
|
||||
|
||||
|
||||
class AzureGCPEncryptionTestMixin:
|
||||
class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
DEK = None
|
||||
KMS_PROVIDER_MAP = None
|
||||
KEYVAULT_DB = "keyvault"
|
||||
@ -1496,7 +1504,7 @@ class AzureGCPEncryptionTestMixin:
|
||||
await create_key_vault(keyvault, self.DEK)
|
||||
|
||||
async def _test_explicit(self, expectation):
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||
async_client_context.client,
|
||||
@ -1525,7 +1533,7 @@ class AzureGCPEncryptionTestMixin:
|
||||
)
|
||||
|
||||
insert_listener = AllowListEventListener("insert")
|
||||
client = await async_rs_or_single_client(
|
||||
client = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
@ -1604,19 +1612,17 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests
|
||||
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
self.client_test = await async_rs_or_single_client(
|
||||
self.client_test = await self.async_rs_or_single_client(
|
||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||
)
|
||||
self.addAsyncCleanup(self.client_test.aclose)
|
||||
|
||||
self.client_keyvault_listener = OvertCommandListener()
|
||||
self.client_keyvault = await async_rs_or_single_client(
|
||||
self.client_keyvault = await self.async_rs_or_single_client(
|
||||
maxPoolSize=1,
|
||||
readConcernLevel="majority",
|
||||
w="majority",
|
||||
event_listeners=[self.client_keyvault_listener],
|
||||
)
|
||||
self.addAsyncCleanup(self.client_keyvault.aclose)
|
||||
|
||||
await self.client_test.keyvault.datakeys.drop()
|
||||
await self.client_test.db.coll.drop()
|
||||
@ -1629,7 +1635,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
codec_options=OPTS,
|
||||
)
|
||||
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=self.client_test,
|
||||
@ -1645,7 +1651,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
|
||||
async def _run_test(self, max_pool_size, auto_encryption_opts):
|
||||
client_encrypted = await async_rs_or_single_client(
|
||||
client_encrypted = await self.async_rs_or_single_client(
|
||||
readConcernLevel="majority",
|
||||
w="majority",
|
||||
maxPoolSize=max_pool_size,
|
||||
@ -1663,8 +1669,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
result = await client_encrypted.db.coll.find_one({"_id": 0})
|
||||
self.assertEqual(result, {"_id": 0, "encrypted": "string0"})
|
||||
|
||||
self.addAsyncCleanup(client_encrypted.close)
|
||||
|
||||
async def test_case_1(self):
|
||||
await self._run_test(
|
||||
max_pool_size=1,
|
||||
@ -1840,7 +1844,7 @@ class TestDecryptProse(AsyncEncryptionIntegrationTest):
|
||||
await create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
|
||||
)
|
||||
keyID = await self.client_encryption.create_data_key("local")
|
||||
@ -1855,10 +1859,9 @@ class TestDecryptProse(AsyncEncryptionIntegrationTest):
|
||||
key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map
|
||||
)
|
||||
self.listener = AllowListEventListener("aggregate")
|
||||
self.encrypted_client = await async_rs_or_single_client(
|
||||
self.encrypted_client = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener]
|
||||
)
|
||||
self.addAsyncCleanup(self.encrypted_client.close)
|
||||
|
||||
async def test_01_command_error(self):
|
||||
async with self.fail_point(
|
||||
@ -1935,8 +1938,7 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
|
||||
"--port=27027",
|
||||
],
|
||||
)
|
||||
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client_encrypted.close)
|
||||
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
with self.assertRaisesRegex(EncryptionError, "Timeout"):
|
||||
await client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
|
||||
@ -1950,11 +1952,10 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
|
||||
"--port=27027",
|
||||
],
|
||||
)
|
||||
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client_encrypted.aclose)
|
||||
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
await client_encrypted.db.coll.insert_one({"unencrypted": "test"})
|
||||
# Validate that mongocryptd was not spawned:
|
||||
mongocryptd_client = AsyncMongoClient(
|
||||
mongocryptd_client = self.simple_client(
|
||||
"mongodb://localhost:27027/?serverSelectionTimeoutMS=500"
|
||||
)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
@ -1978,15 +1979,13 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
|
||||
],
|
||||
crypt_shared_lib_required=True,
|
||||
)
|
||||
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client_encrypted.aclose)
|
||||
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
await client_encrypted.db.coll.drop()
|
||||
await client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
self.assertEncrypted((await async_client_context.client.db.coll.find_one({}))["encrypted"])
|
||||
no_mongocryptd_client = AsyncMongoClient(
|
||||
no_mongocryptd_client = self.simple_client(
|
||||
host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000"
|
||||
)
|
||||
self.addAsyncCleanup(no_mongocryptd_client.aclose)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await no_mongocryptd_client.db.command("ping")
|
||||
|
||||
@ -2020,8 +2019,7 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest):
|
||||
mongocryptd_uri="mongodb://localhost:47021",
|
||||
crypt_shared_lib_required=False,
|
||||
)
|
||||
client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client_encrypted.aclose)
|
||||
client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
await client_encrypted.db.coll.drop()
|
||||
await client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
server.shutdown()
|
||||
@ -2035,10 +2033,9 @@ class TestKmsTLSProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.patch_system_certs(CA_PEM)
|
||||
self.client_encrypted = AsyncClientEncryption(
|
||||
self.client_encrypted = self.create_client_encryption(
|
||||
{"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encrypted.close)
|
||||
|
||||
async def test_invalid_kms_certificate_expired(self):
|
||||
key = {
|
||||
@ -2083,36 +2080,32 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
||||
"gcp": {"tlsCAFile": CA_PEM},
|
||||
"kmip": {"tlsCAFile": CA_PEM},
|
||||
}
|
||||
self.client_encryption_no_client_cert = AsyncClientEncryption(
|
||||
self.client_encryption_no_client_cert = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption_no_client_cert.close)
|
||||
# 2, same providers as above but with tlsCertificateKeyFile.
|
||||
kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only)
|
||||
for p in kms_tls_opts:
|
||||
kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM
|
||||
self.client_encryption_with_tls = AsyncClientEncryption(
|
||||
self.client_encryption_with_tls = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption_with_tls.close)
|
||||
# 3, update endpoints to expired host.
|
||||
providers: dict = copy.deepcopy(providers)
|
||||
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000"
|
||||
providers["gcp"]["endpoint"] = "127.0.0.1:9000"
|
||||
providers["kmip"]["endpoint"] = "127.0.0.1:9000"
|
||||
self.client_encryption_expired = AsyncClientEncryption(
|
||||
self.client_encryption_expired = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption_expired.close)
|
||||
# 3, update endpoints to invalid host.
|
||||
providers: dict = copy.deepcopy(providers)
|
||||
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001"
|
||||
providers["gcp"]["endpoint"] = "127.0.0.1:9001"
|
||||
providers["kmip"]["endpoint"] = "127.0.0.1:9001"
|
||||
self.client_encryption_invalid_hostname = AsyncClientEncryption(
|
||||
self.client_encryption_invalid_hostname = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption_invalid_hostname.close)
|
||||
# Errors when client has no cert, some examples:
|
||||
# [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623)
|
||||
self.cert_error = (
|
||||
@ -2150,7 +2143,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
||||
"gcp:with_tls": with_cert,
|
||||
"kmip:with_tls": with_cert,
|
||||
}
|
||||
self.client_encryption_with_names = AsyncClientEncryption(
|
||||
self.client_encryption_with_names = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4
|
||||
)
|
||||
|
||||
@ -2232,10 +2225,9 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
||||
async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self):
|
||||
providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}
|
||||
options = {"aws": {"tlsDisableOCSPEndpointCheck": True}}
|
||||
encryption = AsyncClientEncryption(
|
||||
encryption = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
|
||||
)
|
||||
self.addAsyncCleanup(encryption.close)
|
||||
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
|
||||
if not hasattr(ctx, "check_ocsp_endpoint"):
|
||||
raise self.skipTest("OCSP not enabled")
|
||||
@ -2285,7 +2277,7 @@ class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
|
||||
self.client = async_client_context.client
|
||||
await create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
|
||||
)
|
||||
self.def_key_id = await self.client_encryption.create_data_key(
|
||||
@ -2327,17 +2319,15 @@ class TestExplicitQueryableEncryption(AsyncEncryptionIntegrationTest):
|
||||
key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document)
|
||||
self.addCleanup(key_vault.drop)
|
||||
self.key_vault_client = self.client
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption.close)
|
||||
opts = AutoEncryptionOpts(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}},
|
||||
key_vault.full_name,
|
||||
bypass_query_analysis=True,
|
||||
)
|
||||
self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(self.encrypted_client.aclose)
|
||||
self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
async def test_01_insert_encrypted_indexed_and_find(self):
|
||||
val = "encrypted indexed value"
|
||||
@ -2464,14 +2454,13 @@ class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest):
|
||||
await self.client.keyvault.drop_collection("datakeys")
|
||||
|
||||
# Step 2. Create a ``AsyncClientEncryption`` object named ``client_encryption1``
|
||||
client_encryption1 = AsyncClientEncryption(
|
||||
client_encryption1 = self.create_client_encryption(
|
||||
key_vault_client=self.client,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
kms_providers=ALL_KMS_PROVIDERS,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
codec_options=OPTS,
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption1.close)
|
||||
|
||||
# Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``.
|
||||
key_id = await client_encryption1.create_data_key(
|
||||
@ -2484,16 +2473,14 @@ class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest):
|
||||
)
|
||||
|
||||
# Step 5. Create a ``AsyncClientEncryption`` object named ``client_encryption2``
|
||||
client2 = await async_rs_or_single_client()
|
||||
self.addAsyncCleanup(client2.aclose)
|
||||
client_encryption2 = AsyncClientEncryption(
|
||||
client2 = await self.async_rs_or_single_client()
|
||||
client_encryption2 = self.create_client_encryption(
|
||||
key_vault_client=client2,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
kms_providers=ALL_KMS_PROVIDERS,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
codec_options=OPTS,
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption2.close)
|
||||
|
||||
# Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``.
|
||||
rewrap_many_data_key_result = await client_encryption2.rewrap_many_data_key(
|
||||
@ -2528,7 +2515,7 @@ class TestOnDemandAWSCredentials(AsyncEncryptionIntegrationTest):
|
||||
|
||||
@unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set")
|
||||
async def test_01_failure(self):
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers={"aws": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
@ -2539,7 +2526,7 @@ class TestOnDemandAWSCredentials(AsyncEncryptionIntegrationTest):
|
||||
|
||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||
async def test_02_success(self):
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers={"aws": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
@ -2559,8 +2546,7 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
|
||||
# AsyncMongoClient to use in testing that handles auth/tls/etc,
|
||||
# and cleanup.
|
||||
async def AsyncMongoClient(**kwargs):
|
||||
c = await async_rs_or_single_client(**kwargs)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
c = await self.async_rs_or_single_client(**kwargs)
|
||||
return c
|
||||
|
||||
# Drop data from prior test runs.
|
||||
@ -2571,7 +2557,7 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
|
||||
|
||||
# Create two data keys.
|
||||
key_vault_client = await AsyncMongoClient()
|
||||
client_encryption = AsyncClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions()
|
||||
)
|
||||
key1_id = await client_encryption.create_data_key("local")
|
||||
@ -2652,18 +2638,16 @@ class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
|
||||
key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document)
|
||||
self.addCleanup(key_vault.drop)
|
||||
self.key_vault_client = self.client
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption.close)
|
||||
opts = AutoEncryptionOpts(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}},
|
||||
key_vault.full_name,
|
||||
bypass_query_analysis=True,
|
||||
)
|
||||
self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.db = self.encrypted_client.db
|
||||
self.addAsyncCleanup(self.encrypted_client.aclose)
|
||||
|
||||
async def run_expression_find(
|
||||
self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None
|
||||
@ -2860,10 +2844,9 @@ class TestRangeQueryDefaultsProse(AsyncEncryptionIntegrationTest):
|
||||
await super().asyncSetUp()
|
||||
await self.client.drop_database(self.db)
|
||||
self.key_vault_client = self.client
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption.close)
|
||||
self.key_id = await self.client_encryption.create_data_key("local")
|
||||
opts = RangeOpts(min=0, max=1000)
|
||||
self.payload_defaults = await self.client_encryption.encrypt(
|
||||
@ -2896,13 +2879,12 @@ class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
|
||||
await self.client.drop_database(self.db)
|
||||
self.key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document)
|
||||
self.addAsyncCleanup(self.key_vault.drop)
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}},
|
||||
self.key_vault.full_name,
|
||||
self.client,
|
||||
OPTS,
|
||||
)
|
||||
self.addAsyncCleanup(self.client_encryption.close)
|
||||
|
||||
async def test_01_simple_create(self):
|
||||
coll, _ = await self.client_encryption.create_encrypted_collection(
|
||||
@ -3118,10 +3100,9 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
|
||||
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.listener = OvertCommandListener()
|
||||
self.mongocryptd_client = AsyncMongoClient(
|
||||
self.mongocryptd_client = self.simple_client(
|
||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||
)
|
||||
self.addAsyncCleanup(self.mongocryptd_client.aclose)
|
||||
|
||||
hello = await self.mongocryptd_client.db.command("hello")
|
||||
self.assertNotIn("logicalSessionTimeoutMinutes", hello)
|
||||
|
||||
@ -33,7 +33,7 @@ from pymongo.asynchronous.database import AsyncDatabase
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.utils import EventListener, async_rs_or_single_client
|
||||
from test.utils import EventListener
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs.asynchronous.grid_file import (
|
||||
@ -792,7 +792,7 @@ Bye"""
|
||||
await outfile.readchunk()
|
||||
|
||||
async def test_grid_in_lazy_connect(self):
|
||||
client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
fs = client.db.fs
|
||||
infile = AsyncGridIn(fs, file_id=-1, chunk_size=1)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
@ -803,7 +803,7 @@ Bye"""
|
||||
async def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs)
|
||||
AsyncGridIn((await self.async_rs_or_single_client(w=0)).pymongo_test.fs)
|
||||
|
||||
async def test_survive_cursor_not_found(self):
|
||||
# By default the find command returns 101 documents in the first batch.
|
||||
@ -811,7 +811,7 @@ Bye"""
|
||||
chunk_size = 1024
|
||||
data = b"d" * (102 * chunk_size)
|
||||
listener = EventListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
db = client.pymongo_test
|
||||
async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile:
|
||||
await infile.write(data)
|
||||
|
||||
@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
import os
|
||||
from test import unittest
|
||||
from test.asynchronous import AsyncIntegrationTest
|
||||
from test.utils import async_single_client
|
||||
from unittest.mock import patch
|
||||
|
||||
from bson import json_util
|
||||
@ -86,7 +85,7 @@ class TestLogger(AsyncIntegrationTest):
|
||||
self.assertEqual(last_3_bytes, str_to_repeat)
|
||||
|
||||
async def test_logging_without_listeners(self):
|
||||
c = await async_single_client()
|
||||
c = await self.async_single_client()
|
||||
self.assertEqual(len(c._event_listeners.event_listeners()), 0)
|
||||
with self.assertLogs("pymongo.connection", level="DEBUG") as cm:
|
||||
await c.db.test.insert_one({"x": "1"})
|
||||
|
||||
@ -31,8 +31,6 @@ from test.asynchronous import (
|
||||
)
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
async_rs_or_single_client,
|
||||
async_single_client,
|
||||
async_wait_until,
|
||||
)
|
||||
|
||||
@ -57,7 +55,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = EventListener()
|
||||
cls.client = await async_rs_or_single_client(
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False
|
||||
)
|
||||
|
||||
@ -407,7 +405,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
async def test_not_primary_error(self):
|
||||
address = next(iter(await async_client_context.client.secondaries))
|
||||
client = await async_single_client(*address, event_listeners=[self.listener])
|
||||
client = await self.async_single_client(*address, event_listeners=[self.listener])
|
||||
# Clear authentication command results from the listener.
|
||||
await client.admin.command("ping")
|
||||
self.listener.reset()
|
||||
@ -1146,7 +1144,7 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
|
||||
# We plan to call register(), which internally modifies _LISTENERS.
|
||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||
monitoring.register(cls.listener)
|
||||
cls.client = await async_single_client()
|
||||
cls.client = await cls.unmanaged_async_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
await cls.client.pymongo_test.command("ping")
|
||||
|
||||
|
||||
@ -36,9 +36,7 @@ from test.asynchronous import (
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
ExceptionCatchingThread,
|
||||
async_rs_or_single_client,
|
||||
async_wait_until,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -90,7 +88,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
await super()._setup_class()
|
||||
# Create a second client so we can make sure clients cannot share
|
||||
# sessions.
|
||||
cls.client2 = await async_rs_or_single_client()
|
||||
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
|
||||
|
||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
@ -105,7 +103,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
self.listener = SessionTestListener()
|
||||
self.session_checker_listener = SessionTestListener()
|
||||
self.client = await async_rs_or_single_client(
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener, self.session_checker_listener]
|
||||
)
|
||||
self.addAsyncCleanup(self.client.close)
|
||||
@ -202,7 +200,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
failures = 0
|
||||
for _ in range(5):
|
||||
listener = EventListener()
|
||||
client = async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
|
||||
client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
|
||||
cursor = client.db.test.find({})
|
||||
ops: List[Tuple[Callable, List[Any]]] = [
|
||||
(client.db.test.find_one, [{"_id": 1}]),
|
||||
@ -285,7 +283,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
async def test_end_sessions(self):
|
||||
# Use a new client so that the tearDown hook does not error.
|
||||
listener = SessionTestListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
# Start many sessions.
|
||||
sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)]
|
||||
for s in sessions:
|
||||
@ -789,8 +787,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
async def test_unacknowledged_writes(self):
|
||||
# Ensure the collection exists.
|
||||
await self.client.pymongo_test.test_unacked_writes.insert_one({})
|
||||
client = await async_rs_or_single_client(w=0, event_listeners=[self.listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener])
|
||||
db = client.pymongo_test
|
||||
coll = db.test_unacked_writes
|
||||
ops: list = [
|
||||
@ -838,7 +835,7 @@ class TestCausalConsistency(AsyncUnitTest):
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.listener = SessionTestListener()
|
||||
cls.client = await async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
@ -1153,10 +1150,9 @@ class TestClusterTime(AsyncIntegrationTest):
|
||||
async def test_cluster_time(self):
|
||||
listener = SessionTestListener()
|
||||
# Prevent heartbeats from updating $clusterTime between operations.
|
||||
client = await async_rs_or_single_client(
|
||||
client = await self.async_rs_or_single_client(
|
||||
event_listeners=[listener], heartbeatFrequencyMS=999999
|
||||
)
|
||||
self.addAsyncCleanup(client.close)
|
||||
collection = client.pymongo_test.collection
|
||||
# Prepare for tests of find() and aggregate().
|
||||
await collection.insert_many([{} for _ in range(10)])
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
|
||||
|
||||
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
|
||||
|
||||
@ -25,8 +26,6 @@ sys.path[0:0] = [""]
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
async_rs_client,
|
||||
async_single_client,
|
||||
wait_until,
|
||||
)
|
||||
from typing import List
|
||||
@ -59,7 +58,18 @@ _IS_SYNC = False
|
||||
UNPIN_TEST_MAX_ATTEMPTS = 50
|
||||
|
||||
|
||||
class TestTransactions(AsyncIntegrationTest):
|
||||
class AsyncTransactionsBase(AsyncSpecRunner):
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
if (
|
||||
"secondary" in self.id()
|
||||
and not async_client_context.is_mongos
|
||||
and not async_client_context.has_secondaries
|
||||
):
|
||||
raise unittest.SkipTest("No secondaries")
|
||||
|
||||
|
||||
class TestTransactions(AsyncTransactionsBase):
|
||||
RUN_ON_SERVERLESS = True
|
||||
|
||||
@async_client_context.require_transactions
|
||||
@ -92,8 +102,7 @@ class TestTransactions(AsyncIntegrationTest):
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_write_concern_override(self):
|
||||
"""Test txn overrides Client/Database/Collection write_concern."""
|
||||
client = await async_rs_client(w=0)
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_client(w=0)
|
||||
db = client.test
|
||||
coll = db.test
|
||||
await coll.insert_one({})
|
||||
@ -150,12 +159,13 @@ class TestTransactions(AsyncIntegrationTest):
|
||||
async def test_unpin_for_next_transaction(self):
|
||||
# Increase localThresholdMS and wait until both nodes are discovered
|
||||
# to avoid false positives.
|
||||
client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||
)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
self.addAsyncCleanup(client.close)
|
||||
async with client.start_session() as s:
|
||||
# Session is pinned to Mongos.
|
||||
async with await s.start_transaction():
|
||||
@ -178,12 +188,13 @@ class TestTransactions(AsyncIntegrationTest):
|
||||
async def test_unpin_for_non_transaction_operation(self):
|
||||
# Increase localThresholdMS and wait until both nodes are discovered
|
||||
# to avoid false positives.
|
||||
client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||
)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
self.addAsyncCleanup(client.close)
|
||||
async with client.start_session() as s:
|
||||
# Session is pinned to Mongos.
|
||||
async with await s.start_transaction():
|
||||
@ -307,11 +318,10 @@ class TestTransactions(AsyncIntegrationTest):
|
||||
# Start a transaction with a batch of operations that needs to be
|
||||
# split.
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
client = await self.async_rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
await coll.delete_many({})
|
||||
listener.reset()
|
||||
self.addAsyncCleanup(client.close)
|
||||
self.addAsyncCleanup(coll.drop)
|
||||
large_str = "\0" * (1 * 1024 * 1024)
|
||||
ops: List[InsertOne[RawBSONDocument]] = [
|
||||
@ -336,8 +346,7 @@ class TestTransactions(AsyncIntegrationTest):
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_direct_connection(self):
|
||||
client = await async_single_client()
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_single_client()
|
||||
coll = client.pymongo_test.test
|
||||
|
||||
# Make sure the collection exists.
|
||||
@ -393,14 +402,16 @@ class PatchSessionTimeout:
|
||||
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout
|
||||
|
||||
|
||||
class TestTransactionsConvenientAPI(AsyncIntegrationTest):
|
||||
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
if async_client_context.supports_transactions():
|
||||
for address in async_client_context.mongoses:
|
||||
cls.mongos_clients.append(await async_single_client("{}:{}".format(*address)))
|
||||
cls.mongos_clients.append(
|
||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
@ -450,8 +461,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest):
|
||||
@async_client_context.require_transactions
|
||||
async def test_callback_not_retried_after_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
|
||||
async def callback(session):
|
||||
@ -479,8 +489,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest):
|
||||
@async_client_context.require_transactions
|
||||
async def test_callback_not_retried_after_commit_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
|
||||
async def callback(session):
|
||||
@ -514,8 +523,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest):
|
||||
@async_client_context.require_transactions
|
||||
async def test_commit_not_retried_after_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
client = await self.async_rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
|
||||
async def callback(session):
|
||||
|
||||
@ -25,7 +25,6 @@ from test.utils import (
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
ServerAndTopologyEventListener,
|
||||
async_rs_client,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
parse_spec_options,
|
||||
@ -101,6 +100,8 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
@ -527,7 +528,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
host = async_client_context.MULTI_MONGOS_LB_URI
|
||||
elif async_client_context.is_mongos:
|
||||
host = async_client_context.mongos_seeds()
|
||||
client = await async_rs_client(
|
||||
client = await self.async_rs_client(
|
||||
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
|
||||
)
|
||||
self.scenario_client = client
|
||||
|
||||
@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from test import PyMongoTestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -36,7 +37,7 @@ from pymongo.uri_parser import parse_uri
|
||||
pytestmark = pytest.mark.auth_aws
|
||||
|
||||
|
||||
class TestAuthAWS(unittest.TestCase):
|
||||
class TestAuthAWS(PyMongoTestCase):
|
||||
uri: str
|
||||
|
||||
@classmethod
|
||||
@ -69,7 +70,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
self.skipTest("Not testing cached credentials")
|
||||
|
||||
# Make a connection to ensure that we enable caching.
|
||||
client = MongoClient(self.uri)
|
||||
client = self.simple_client(self.uri)
|
||||
client.get_database().test.find_one()
|
||||
client.close()
|
||||
|
||||
@ -79,7 +80,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
auth.set_cached_credentials(None)
|
||||
self.assertEqual(auth.get_cached_credentials(), None)
|
||||
|
||||
client = MongoClient(self.uri)
|
||||
client = self.simple_client(self.uri)
|
||||
client.get_database().test.find_one()
|
||||
client.close()
|
||||
return auth.get_cached_credentials()
|
||||
@ -90,8 +91,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
|
||||
def test_cache_about_to_expire(self):
|
||||
creds = self.setup_cache()
|
||||
client = MongoClient(self.uri)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.uri)
|
||||
|
||||
# Make the creds about to expire.
|
||||
creds = auth.get_cached_credentials()
|
||||
@ -107,8 +107,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
def test_poisoned_cache(self):
|
||||
creds = self.setup_cache()
|
||||
|
||||
client = MongoClient(self.uri)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.uri)
|
||||
|
||||
# Poison the creds with invalid password.
|
||||
assert creds is not None
|
||||
@ -130,8 +129,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
self.assertIsNotNone(creds)
|
||||
os.environ.copy()
|
||||
|
||||
client = MongoClient(self.uri)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.uri)
|
||||
|
||||
client.get_database().test.find_one()
|
||||
|
||||
@ -149,8 +147,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
|
||||
auth.set_cached_credentials(None)
|
||||
|
||||
client2 = MongoClient(self.uri)
|
||||
self.addCleanup(client2.close)
|
||||
client2 = self.simple_client(self.uri)
|
||||
|
||||
with patch.dict("os.environ", mock_env):
|
||||
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo")
|
||||
@ -166,8 +163,7 @@ class TestAuthAWS(unittest.TestCase):
|
||||
if creds.token:
|
||||
mock_env["AWS_SESSION_TOKEN"] = creds.token
|
||||
|
||||
client = MongoClient(self.uri)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.uri)
|
||||
|
||||
with patch.dict(os.environ, mock_env):
|
||||
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username)
|
||||
@ -177,22 +173,19 @@ class TestAuthAWS(unittest.TestCase):
|
||||
|
||||
mock_env["AWS_ACCESS_KEY_ID"] = "foo"
|
||||
|
||||
client2 = MongoClient(self.uri)
|
||||
self.addCleanup(client2.close)
|
||||
client2 = self.simple_client(self.uri)
|
||||
|
||||
with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure):
|
||||
self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo")
|
||||
client2.get_database().test.find_one()
|
||||
|
||||
|
||||
class TestAWSLambdaExamples(unittest.TestCase):
|
||||
class TestAWSLambdaExamples(PyMongoTestCase):
|
||||
def test_shared_client(self):
|
||||
# Start AWS Lambda Example 1
|
||||
import os
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
client = MongoClient(host=os.environ["MONGODB_URI"])
|
||||
client = self.simple_client(host=os.environ["MONGODB_URI"])
|
||||
|
||||
def lambda_handler(event, context):
|
||||
return client.db.command("ping")
|
||||
@ -203,9 +196,7 @@ class TestAWSLambdaExamples(unittest.TestCase):
|
||||
# Start AWS Lambda Example 2
|
||||
import os
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
host=os.environ["MONGODB_URI"],
|
||||
authSource="$external",
|
||||
authMechanism="MONGODB-AWS",
|
||||
|
||||
@ -23,6 +23,7 @@ import unittest
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from test import PyMongoTestCase
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
@ -56,7 +57,7 @@ globals().update(generate_test_classes(str(TEST_PATH), module=__name__))
|
||||
pytestmark = pytest.mark.auth_oidc
|
||||
|
||||
|
||||
class OIDCTestBase(unittest.TestCase):
|
||||
class OIDCTestBase(PyMongoTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
|
||||
@ -94,6 +95,7 @@ class OIDCTestBase(unittest.TestCase):
|
||||
yield
|
||||
finally:
|
||||
client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off")
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.auth_oidc
|
||||
@ -149,7 +151,9 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
if not len(args):
|
||||
args = [self.uri_single]
|
||||
|
||||
return MongoClient(*args, authmechanismproperties=props, **kwargs)
|
||||
client = self.simple_client(*args, authmechanismproperties=props, **kwargs)
|
||||
|
||||
return client
|
||||
|
||||
def test_1_1_single_principal_implicit_username(self):
|
||||
# Create default OIDC client with authMechanism=MONGODB-OIDC.
|
||||
|
||||
@ -29,13 +29,12 @@ except ImportError:
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import OperationFailure
|
||||
|
||||
pytestmark = pytest.mark.mockupdb
|
||||
|
||||
|
||||
class TestCursor(unittest.TestCase):
|
||||
class TestCursor(PyMongoTestCase):
|
||||
def test_getmore_load_balanced(self):
|
||||
server = MockupDB()
|
||||
server.autoresponds(
|
||||
@ -50,7 +49,7 @@ class TestCursor(unittest.TestCase):
|
||||
server.run()
|
||||
self.addCleanup(server.stop)
|
||||
|
||||
client = MongoClient(server.uri, loadBalanced=True)
|
||||
client = self.simple_client(server.uri, loadBalanced=True)
|
||||
self.addCleanup(client.close)
|
||||
collection = client.db.coll
|
||||
cursor = collection.find()
|
||||
@ -77,7 +76,7 @@ class TestRetryableErrorCodeCatch(PyMongoTestCase):
|
||||
self.addCleanup(server.stop)
|
||||
server.autoresponds("ismaster", maxWireVersion=6)
|
||||
|
||||
client = MongoClient(server.uri)
|
||||
client = self.simple_client(server.uri)
|
||||
|
||||
with going(lambda: server.receives(OpMsg({"find": "collection"})).command_err(code=code)):
|
||||
cursor = client.db.collection.find()
|
||||
|
||||
@ -48,8 +48,11 @@ else:
|
||||
def _connect(options):
|
||||
uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS={TIMEOUT_MS}&tlsCAFile={CA_FILE}&{options}"
|
||||
print(uri)
|
||||
client = pymongo.MongoClient(uri)
|
||||
client.admin.command("ping")
|
||||
try:
|
||||
client = pymongo.MongoClient(uri)
|
||||
client.admin.command("ping")
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
class TestOCSP(unittest.TestCase):
|
||||
|
||||
@ -23,16 +23,14 @@ from urllib.parse import quote_plus
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, SkipTest, client_context, unittest
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
rs_or_single_client,
|
||||
rs_or_single_client_noauth,
|
||||
single_client,
|
||||
single_client_noauth,
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
PyMongoTestCase,
|
||||
SkipTest,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import AllowListEventListener, delay, ignore_deprecations
|
||||
|
||||
from pymongo import MongoClient, monitoring
|
||||
from pymongo.auth_shared import _build_credentials_tuple
|
||||
@ -81,7 +79,7 @@ class AutoAuthenticateThread(threading.Thread):
|
||||
self.success = True
|
||||
|
||||
|
||||
class TestGSSAPI(unittest.TestCase):
|
||||
class TestGSSAPI(PyMongoTestCase):
|
||||
mech_properties: str
|
||||
service_realm_required: bool
|
||||
|
||||
@ -138,7 +136,7 @@ class TestGSSAPI(unittest.TestCase):
|
||||
|
||||
if not self.service_realm_required:
|
||||
# Without authMechanismProperties.
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -149,11 +147,11 @@ class TestGSSAPI(unittest.TestCase):
|
||||
client[GSSAPI_DB].collection.find_one()
|
||||
|
||||
# Log in using URI, without authMechanismProperties.
|
||||
client = MongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
client[GSSAPI_DB].collection.find_one()
|
||||
|
||||
# Authenticate with authMechanismProperties.
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -166,14 +164,14 @@ class TestGSSAPI(unittest.TestCase):
|
||||
|
||||
# Log in using URI, with authMechanismProperties.
|
||||
mech_uri = uri + f"&authMechanismProperties={self.mech_properties}"
|
||||
client = MongoClient(mech_uri)
|
||||
client = self.simple_client(mech_uri)
|
||||
client[GSSAPI_DB].collection.find_one()
|
||||
|
||||
set_name = client_context.replica_set_name
|
||||
if set_name:
|
||||
if not self.service_realm_required:
|
||||
# Without authMechanismProperties
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -185,11 +183,11 @@ class TestGSSAPI(unittest.TestCase):
|
||||
client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
uri = uri + f"&replicaSet={set_name!s}"
|
||||
client = MongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
# With authMechanismProperties
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -202,13 +200,13 @@ class TestGSSAPI(unittest.TestCase):
|
||||
client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
mech_uri = mech_uri + f"&replicaSet={set_name!s}"
|
||||
client = MongoClient(mech_uri)
|
||||
client = self.simple_client(mech_uri)
|
||||
client[GSSAPI_DB].list_collection_names()
|
||||
|
||||
@ignore_deprecations
|
||||
@client_context.require_sync
|
||||
def test_gssapi_threaded(self):
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -244,7 +242,7 @@ class TestGSSAPI(unittest.TestCase):
|
||||
|
||||
set_name = client_context.replica_set_name
|
||||
if set_name:
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
GSSAPI_HOST,
|
||||
GSSAPI_PORT,
|
||||
username=GSSAPI_PRINCIPAL,
|
||||
@ -267,14 +265,14 @@ class TestGSSAPI(unittest.TestCase):
|
||||
self.assertTrue(thread.success)
|
||||
|
||||
|
||||
class TestSASLPlain(unittest.TestCase):
|
||||
class TestSASLPlain(PyMongoTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not SASL_HOST or not SASL_USER or not SASL_PASS:
|
||||
raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL")
|
||||
|
||||
def test_sasl_plain(self):
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
SASL_HOST,
|
||||
SASL_PORT,
|
||||
username=SASL_USER,
|
||||
@ -293,12 +291,12 @@ class TestSASLPlain(unittest.TestCase):
|
||||
SASL_PORT,
|
||||
SASL_DB,
|
||||
)
|
||||
client = MongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
client.ldap.test.find_one()
|
||||
|
||||
set_name = client_context.replica_set_name
|
||||
if set_name:
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
SASL_HOST,
|
||||
SASL_PORT,
|
||||
replicaSet=set_name,
|
||||
@ -317,7 +315,7 @@ class TestSASLPlain(unittest.TestCase):
|
||||
SASL_DB,
|
||||
str(set_name),
|
||||
)
|
||||
client = MongoClient(uri)
|
||||
client = self.simple_client(uri)
|
||||
client.ldap.test.find_one()
|
||||
|
||||
def test_sasl_plain_bad_credentials(self):
|
||||
@ -331,8 +329,8 @@ class TestSASLPlain(unittest.TestCase):
|
||||
)
|
||||
return uri
|
||||
|
||||
bad_user = MongoClient(auth_string("not-user", SASL_PASS))
|
||||
bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd"))
|
||||
bad_user = self.simple_client(auth_string("not-user", SASL_PASS))
|
||||
bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd"))
|
||||
# OperationFailure raised upon connecting.
|
||||
with self.assertRaises(OperationFailure):
|
||||
bad_user.admin.command("ping")
|
||||
@ -354,7 +352,7 @@ class TestSCRAMSHA1(IntegrationTest):
|
||||
def test_scram_sha1(self):
|
||||
host, port = client_context.host, client_context.port
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
"mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port)
|
||||
)
|
||||
client.pymongo_test.command("dbstats")
|
||||
@ -365,7 +363,7 @@ class TestSCRAMSHA1(IntegrationTest):
|
||||
"@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1"
|
||||
"&replicaSet=%s" % (host, port, client_context.replica_set_name)
|
||||
)
|
||||
client = single_client_noauth(uri)
|
||||
client = self.single_client_noauth(uri)
|
||||
client.pymongo_test.command("dbstats")
|
||||
db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY)
|
||||
db.command("dbstats")
|
||||
@ -393,7 +391,7 @@ class TestSCRAM(IntegrationTest):
|
||||
"testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
|
||||
)
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram", event_listeners=[listener]
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
@ -430,36 +428,38 @@ class TestSCRAM(IntegrationTest):
|
||||
)
|
||||
|
||||
# Step 2: verify auth success cases
|
||||
client = rs_or_single_client_noauth(username="sha1", password="pwd", authSource="testscram")
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha1", password="pwd", authSource="testscram"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
# Step 2: SCRAM-SHA-1 and SCRAM-SHA-256
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
self.listener.reset()
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="both", password="pwd", authSource="testscram", event_listeners=[self.listener]
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
@ -472,19 +472,19 @@ class TestSCRAM(IntegrationTest):
|
||||
self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256")
|
||||
|
||||
# Step 3: verify auth failure conditions
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1"
|
||||
)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="not-a-user", password="pwd", authSource="testscram"
|
||||
)
|
||||
with self.assertRaises(OperationFailure):
|
||||
@ -497,7 +497,7 @@ class TestSCRAM(IntegrationTest):
|
||||
port,
|
||||
client_context.replica_set_name,
|
||||
)
|
||||
client = single_client_noauth(uri)
|
||||
client = self.single_client_noauth(uri)
|
||||
client.testscram.command("dbstats")
|
||||
db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY)
|
||||
db.command("dbstats")
|
||||
@ -517,12 +517,12 @@ class TestSCRAM(IntegrationTest):
|
||||
"testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"]
|
||||
)
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="\u2168", password="\u2163", authSource="testscram"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="\u2168",
|
||||
password="\u2163",
|
||||
authSource="testscram",
|
||||
@ -530,17 +530,17 @@ class TestSCRAM(IntegrationTest):
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="\u2168", password="IV", authSource="testscram"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="IX", password="I\u00ADX", authSource="testscram"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="IX",
|
||||
password="I\u00ADX",
|
||||
authSource="testscram",
|
||||
@ -548,25 +548,29 @@ class TestSCRAM(IntegrationTest):
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256"
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
client = self.rs_or_single_client_noauth(
|
||||
"mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
client = rs_or_single_client_noauth("mongodb://\u2168:IV@%s:%d/testscram" % (host, port))
|
||||
client = self.rs_or_single_client_noauth(
|
||||
"mongodb://\u2168:IV@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
client = rs_or_single_client_noauth("mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port))
|
||||
client = self.rs_or_single_client_noauth(
|
||||
"mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)
|
||||
)
|
||||
client.testscram.command("dbstats")
|
||||
client = rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port))
|
||||
client = self.rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port))
|
||||
client.testscram.command("dbstats")
|
||||
|
||||
def test_cache(self):
|
||||
client = single_client()
|
||||
client = self.single_client()
|
||||
credentials = client.options.pool_options._credentials
|
||||
cache = credentials.cache
|
||||
self.assertIsNotNone(cache)
|
||||
@ -591,8 +595,7 @@ class TestSCRAM(IntegrationTest):
|
||||
coll.insert_one({"_id": 1})
|
||||
|
||||
# The first thread to call find() will authenticate
|
||||
client = rs_or_single_client()
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client()
|
||||
coll = client.db.test
|
||||
threads = []
|
||||
for _ in range(4):
|
||||
@ -619,7 +622,7 @@ class TestAuthURIOptions(IntegrationTest):
|
||||
def test_uri_options(self):
|
||||
# Test default to admin
|
||||
host, port = client_context.host, client_context.port
|
||||
client = rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))
|
||||
client = self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))
|
||||
self.assertTrue(client.admin.command("dbstats"))
|
||||
|
||||
if client_context.is_rs:
|
||||
@ -628,14 +631,14 @@ class TestAuthURIOptions(IntegrationTest):
|
||||
port,
|
||||
client_context.replica_set_name,
|
||||
)
|
||||
client = single_client_noauth(uri)
|
||||
client = self.single_client_noauth(uri)
|
||||
self.assertTrue(client.admin.command("dbstats"))
|
||||
db = client.get_database("admin", read_preference=ReadPreference.SECONDARY)
|
||||
self.assertTrue(db.command("dbstats"))
|
||||
|
||||
# Test explicit database
|
||||
uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port)
|
||||
client = rs_or_single_client_noauth(uri)
|
||||
client = self.rs_or_single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.admin.command("dbstats")
|
||||
self.assertTrue(client.pymongo_test.command("dbstats"))
|
||||
@ -646,7 +649,7 @@ class TestAuthURIOptions(IntegrationTest):
|
||||
port,
|
||||
client_context.replica_set_name,
|
||||
)
|
||||
client = single_client_noauth(uri)
|
||||
client = self.single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.admin.command("dbstats")
|
||||
self.assertTrue(client.pymongo_test.command("dbstats"))
|
||||
@ -655,7 +658,7 @@ class TestAuthURIOptions(IntegrationTest):
|
||||
|
||||
# Test authSource
|
||||
uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port)
|
||||
client = rs_or_single_client_noauth(uri)
|
||||
client = self.rs_or_single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.pymongo_test2.command("dbstats")
|
||||
self.assertTrue(client.pymongo_test.command("dbstats"))
|
||||
@ -665,7 +668,7 @@ class TestAuthURIOptions(IntegrationTest):
|
||||
"mongodb://user:pass@%s:%d/pymongo_test2?replicaSet="
|
||||
"%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name)
|
||||
)
|
||||
client = single_client_noauth(uri)
|
||||
client = self.single_client_noauth(uri)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.pymongo_test2.command("dbstats")
|
||||
self.assertTrue(client.pymongo_test.command("dbstats"))
|
||||
|
||||
@ -20,6 +20,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from test import PyMongoTestCase
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -34,7 +35,7 @@ _IS_SYNC = True
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")
|
||||
|
||||
|
||||
class TestAuthSpec(unittest.TestCase):
|
||||
class TestAuthSpec(PyMongoTestCase):
|
||||
pass
|
||||
|
||||
|
||||
@ -54,7 +55,7 @@ def create_test(test_case):
|
||||
warnings.simplefilter("default")
|
||||
self.assertRaises(Exception, MongoClient, uri, connect=False)
|
||||
else:
|
||||
client = MongoClient(uri, connect=False)
|
||||
client = self.simple_client(uri, connect=False)
|
||||
credentials = client.options.pool_options._credentials
|
||||
if credential is None:
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
@ -24,22 +24,13 @@ from pymongo.synchronous.mongo_client import MongoClient
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, remove_all_users, unittest
|
||||
from test.utils import (
|
||||
rs_or_single_client_noauth,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils import wait_until
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure
|
||||
from pymongo.operations import *
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.write_concern import WriteConcern
|
||||
@ -913,7 +904,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
|
||||
def test_readonly(self):
|
||||
# We test that an authorization failure aborts the batch and is raised
|
||||
# as OperationFailure.
|
||||
cli = rs_or_single_client_noauth(
|
||||
cli = self.rs_or_single_client_noauth(
|
||||
username="readonly", password="pw", authSource="pymongo_test"
|
||||
)
|
||||
coll = cli.pymongo_test.test
|
||||
@ -924,7 +915,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
|
||||
def test_no_remove(self):
|
||||
# We test that an authorization failure aborts the batch and is raised
|
||||
# as OperationFailure.
|
||||
cli = rs_or_single_client_noauth(
|
||||
cli = self.rs_or_single_client_noauth(
|
||||
username="noremove", password="pw", authSource="pymongo_test"
|
||||
)
|
||||
coll = cli.pymongo_test.test
|
||||
@ -952,7 +943,7 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
if cls.w is not None and cls.w > 1:
|
||||
for member in (client_context.hello)["hosts"]:
|
||||
if member != (client_context.hello)["primary"]:
|
||||
cls.secondary = single_client(*partition_node(member))
|
||||
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -28,12 +28,17 @@ from typing import no_type_check
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, Version, client_context, unittest
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
PyMongoTestCase,
|
||||
Version,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
EventListener,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -69,8 +74,7 @@ class TestChangeStreamBase(IntegrationTest):
|
||||
def client_with_listener(self, *commands):
|
||||
"""Return a client with a AllowListEventListener."""
|
||||
listener = AllowListEventListener(*commands)
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
return client, listener
|
||||
|
||||
def watched_collection(self, *args, **kwargs):
|
||||
@ -174,7 +178,7 @@ class APITestsMixin:
|
||||
@no_type_check
|
||||
def test_try_next_runs_one_getmore(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
# Connect to the cluster.
|
||||
client.admin.command("ping")
|
||||
listener.reset()
|
||||
@ -232,7 +236,7 @@ class APITestsMixin:
|
||||
@no_type_check
|
||||
def test_batch_size_is_honored(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
# Connect to the cluster.
|
||||
client.admin.command("ping")
|
||||
listener.reset()
|
||||
@ -473,7 +477,7 @@ class ProseSpecTestsMixin:
|
||||
@no_type_check
|
||||
def _client_with_listener(self, *commands):
|
||||
listener = AllowListEventListener(*commands)
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
return client, listener
|
||||
|
||||
@ -1111,7 +1115,7 @@ class TestAllLegacyScenarios(IntegrationTest):
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -27,7 +27,6 @@ from test import (
|
||||
)
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
rs_or_single_client,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -38,7 +37,6 @@ from pymongo.errors import (
|
||||
InvalidOperation,
|
||||
NetworkTimeout,
|
||||
)
|
||||
from pymongo.monitoring import *
|
||||
from pymongo.operations import *
|
||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||
from pymongo.write_concern import WriteConcern
|
||||
@ -97,8 +95,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_batch_splits_if_num_operations_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
models = []
|
||||
for _ in range(self.max_write_batch_size + 1):
|
||||
@ -123,8 +120,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_batch_splits_if_ops_payload_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
models = []
|
||||
num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1)
|
||||
@ -157,11 +153,10 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_failCommand_fail_point
|
||||
def test_collects_write_concern_errors_across_batches(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
event_listeners=[listener],
|
||||
retryWrites=False,
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
@ -200,8 +195,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_collects_write_errors_across_batches_unordered(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addCleanup(collection.drop)
|
||||
@ -231,8 +225,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_collects_write_errors_across_batches_ordered(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addCleanup(collection.drop)
|
||||
@ -262,8 +255,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_handles_cursor_requiring_getMore(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addCleanup(collection.drop)
|
||||
@ -304,8 +296,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_standalone
|
||||
def test_handles_cursor_requiring_getMore_within_transaction(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addCleanup(collection.drop)
|
||||
@ -348,8 +339,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_failCommand_fail_point
|
||||
def test_handles_getMore_error(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addCleanup(collection.drop)
|
||||
@ -403,8 +393,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_returns_error_if_unacknowledged_too_large_insert(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
b_repeated = "b" * self.max_bson_object_size
|
||||
|
||||
@ -460,8 +449,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_no_batch_splits_if_new_namespace_is_not_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
num_models, models = self._setup_namespace_test_models()
|
||||
models.append(
|
||||
@ -492,8 +480,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
def test_batch_splits_if_new_namespace_is_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
|
||||
num_models, models = self._setup_namespace_test_models()
|
||||
c_repeated = "c" * 200
|
||||
@ -530,8 +517,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
@client_context.require_version_min(8, 0, 0, -24)
|
||||
@client_context.require_no_serverless
|
||||
def test_returns_error_if_no_writes_can_be_added_to_ops(self):
|
||||
client = rs_or_single_client()
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client()
|
||||
|
||||
# Document too large.
|
||||
b_repeated = "b" * self.max_message_size_bytes
|
||||
@ -554,8 +540,7 @@ class TestClientBulkWriteCRUD(IntegrationTest):
|
||||
key_vault_namespace="db.coll",
|
||||
kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}},
|
||||
)
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
|
||||
with self.assertRaises(InvalidOperation) as context:
|
||||
@ -580,7 +565,7 @@ class TestClientBulkWriteCSOT(IntegrationTest):
|
||||
def test_timeout_in_multi_batch_bulk_write(self):
|
||||
_OVERHEAD = 500
|
||||
|
||||
internal_client = rs_or_single_client(timeoutMS=None)
|
||||
internal_client = self.rs_or_single_client(timeoutMS=None)
|
||||
self.addCleanup(internal_client.close)
|
||||
|
||||
collection = internal_client.db["coll"]
|
||||
@ -605,14 +590,13 @@ class TestClientBulkWriteCSOT(IntegrationTest):
|
||||
)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
event_listeners=[listener],
|
||||
readConcernLevel="majority",
|
||||
readPreference="primary",
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping") # Init the client first.
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
client.bulk_write(models=models)
|
||||
|
||||
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
import warnings
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import EventListener, rs_or_single_client
|
||||
from test.utils import EventListener
|
||||
from typing import Any
|
||||
|
||||
from pymongo.collation import (
|
||||
@ -99,7 +99,7 @@ class TestCollation(IntegrationTest):
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.listener = EventListener()
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
cls.collation = Collation("en_US")
|
||||
cls.warn_context = warnings.catch_warnings()
|
||||
|
||||
@ -29,6 +29,7 @@ sys.path[0:0] = [""]
|
||||
|
||||
from test import ( # TODO: fix sync imports in PYTHON-4528
|
||||
IntegrationTest,
|
||||
UnitTest,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
@ -37,8 +38,6 @@ from test.utils import (
|
||||
EventListener,
|
||||
get_pool,
|
||||
is_mongos,
|
||||
rs_or_single_client,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -81,14 +80,20 @@ from pymongo.write_concern import WriteConcern
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestCollectionNoConnect(unittest.TestCase):
|
||||
class TestCollectionNoConnect(UnitTest):
|
||||
"""Test Collection features on a client that does not connect."""
|
||||
|
||||
db: Database
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.db = MongoClient(connect=False).pymongo_test
|
||||
def _setup_class(cls):
|
||||
cls.client = MongoClient(connect=False)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, Collection, self.db, 5)
|
||||
@ -1800,8 +1805,7 @@ class TestCollection(IntegrationTest):
|
||||
# Insert enough documents to require more than one batch
|
||||
self.db.test.insert_many([{"i": i} for i in range(150)])
|
||||
|
||||
client = rs_or_single_client(maxPoolSize=1)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(maxPoolSize=1)
|
||||
pool = get_pool(client)
|
||||
|
||||
# Make sure the socket is returned after exhaustion.
|
||||
@ -2077,7 +2081,7 @@ class TestCollection(IntegrationTest):
|
||||
|
||||
def test_find_one_and_write_concern(self):
|
||||
listener = EventListener()
|
||||
db = (single_client(event_listeners=[listener]))[self.db.name]
|
||||
db = (self.single_client(event_listeners=[listener]))[self.db.name]
|
||||
# non-default WriteConcern.
|
||||
c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0))
|
||||
# default WriteConcern.
|
||||
|
||||
@ -22,7 +22,7 @@ import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import EventListener, rs_or_single_client
|
||||
from test.utils import EventListener
|
||||
|
||||
from bson.dbref import DBRef
|
||||
from pymongo.operations import IndexModel
|
||||
@ -109,7 +109,7 @@ class TestComment(IntegrationTest):
|
||||
@client_context.require_replica_set
|
||||
def test_database_helpers(self):
|
||||
listener = EventListener()
|
||||
db = rs_or_single_client(event_listeners=[listener]).db
|
||||
db = self.rs_or_single_client(event_listeners=[listener]).db
|
||||
helpers = [
|
||||
(db.watch, []),
|
||||
(db.command, ["hello"]),
|
||||
@ -126,7 +126,7 @@ class TestComment(IntegrationTest):
|
||||
@client_context.require_replica_set
|
||||
def test_client_helpers(self):
|
||||
listener = EventListener()
|
||||
cli = rs_or_single_client(event_listeners=[listener])
|
||||
cli = self.rs_or_single_client(event_listeners=[listener])
|
||||
helpers = [
|
||||
(cli.watch, []),
|
||||
(cli.list_databases, []),
|
||||
@ -141,7 +141,7 @@ class TestComment(IntegrationTest):
|
||||
@client_context.require_version_min(4, 7, -1)
|
||||
def test_collection_helpers(self):
|
||||
listener = EventListener()
|
||||
db = rs_or_single_client(event_listeners=[listener])[self.db.name]
|
||||
db = self.rs_or_single_client(event_listeners=[listener])[self.db.name]
|
||||
coll = db.get_collection("test")
|
||||
|
||||
helpers = [
|
||||
|
||||
@ -21,7 +21,6 @@ import uuid
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, connected, unittest
|
||||
from test.utils import rs_or_single_client, single_client
|
||||
|
||||
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
@ -111,10 +110,10 @@ class TestCommon(IntegrationTest):
|
||||
)
|
||||
|
||||
def test_write_concern(self):
|
||||
c = rs_or_single_client(connect=False)
|
||||
c = self.rs_or_single_client(connect=False)
|
||||
self.assertEqual(WriteConcern(), c.write_concern)
|
||||
|
||||
c = rs_or_single_client(connect=False, w=2, wTimeoutMS=1000)
|
||||
c = self.rs_or_single_client(connect=False, w=2, wTimeoutMS=1000)
|
||||
wc = WriteConcern(w=2, wtimeout=1000)
|
||||
self.assertEqual(wc, c.write_concern)
|
||||
|
||||
@ -134,7 +133,7 @@ class TestCommon(IntegrationTest):
|
||||
|
||||
def test_mongo_client(self):
|
||||
pair = client_context.pair
|
||||
m = rs_or_single_client(w=0)
|
||||
m = self.rs_or_single_client(w=0)
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
coll.drop()
|
||||
doc = {"_id": ObjectId()}
|
||||
@ -143,17 +142,19 @@ class TestCommon(IntegrationTest):
|
||||
coll = coll.with_options(write_concern=WriteConcern(w=1))
|
||||
self.assertRaises(OperationFailure, coll.insert_one, doc)
|
||||
|
||||
m = rs_or_single_client()
|
||||
m = self.rs_or_single_client()
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
new_coll = coll.with_options(write_concern=WriteConcern(w=0))
|
||||
self.assertTrue(new_coll.insert_one(doc))
|
||||
self.assertRaises(OperationFailure, coll.insert_one, doc)
|
||||
|
||||
m = rs_or_single_client(f"mongodb://{pair}/", replicaSet=client_context.replica_set_name)
|
||||
m = self.rs_or_single_client(
|
||||
f"mongodb://{pair}/", replicaSet=client_context.replica_set_name
|
||||
)
|
||||
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
self.assertRaises(OperationFailure, coll.insert_one, doc)
|
||||
m = rs_or_single_client(
|
||||
m = self.rs_or_single_client(
|
||||
f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name
|
||||
)
|
||||
|
||||
@ -161,8 +162,8 @@ class TestCommon(IntegrationTest):
|
||||
coll.insert_one(doc)
|
||||
|
||||
# Equality tests
|
||||
direct = connected(single_client(w=0))
|
||||
direct2 = connected(single_client(f"mongodb://{pair}/?w=0", **self.credentials))
|
||||
direct = connected(self.single_client(w=0))
|
||||
direct2 = connected(self.single_client(f"mongodb://{pair}/?w=0", **self.credentials))
|
||||
self.assertEqual(direct, direct2)
|
||||
self.assertFalse(direct != direct2)
|
||||
|
||||
|
||||
@ -30,9 +30,6 @@ from test.utils import (
|
||||
client_context,
|
||||
get_pool,
|
||||
get_pools,
|
||||
rs_or_single_client,
|
||||
single_client,
|
||||
single_client_noauth,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunnerThread
|
||||
@ -250,7 +247,7 @@ class TestCMAP(IntegrationTest):
|
||||
else:
|
||||
kill_cursor_frequency = interval / 1000.0
|
||||
with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05):
|
||||
client = single_client(**opts)
|
||||
client = self.single_client(**opts)
|
||||
# Update the SD to a known type because the DummyMonitor will not.
|
||||
# Note we cannot simply call topology.on_change because that would
|
||||
# internally call pool.ready() which introduces unexpected
|
||||
@ -323,13 +320,13 @@ class TestCMAP(IntegrationTest):
|
||||
# Prose tests. Numbers correspond to the prose test number in the spec.
|
||||
#
|
||||
def test_1_client_connection_pool_options(self):
|
||||
client = rs_or_single_client(**self.POOL_OPTIONS)
|
||||
client = self.rs_or_single_client(**self.POOL_OPTIONS)
|
||||
self.addCleanup(client.close)
|
||||
pool_opts = get_pool(client).opts
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
|
||||
def test_2_all_client_pools_have_same_options(self):
|
||||
client = rs_or_single_client(**self.POOL_OPTIONS)
|
||||
client = self.rs_or_single_client(**self.POOL_OPTIONS)
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping")
|
||||
# Discover at least one secondary.
|
||||
@ -345,14 +342,14 @@ class TestCMAP(IntegrationTest):
|
||||
def test_3_uri_connection_pool_options(self):
|
||||
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
|
||||
uri = f"mongodb://{client_context.pair}/?{opts}"
|
||||
client = rs_or_single_client(uri)
|
||||
client = self.rs_or_single_client(uri)
|
||||
self.addCleanup(client.close)
|
||||
pool_opts = get_pool(client).opts
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
|
||||
def test_4_subscribe_to_events(self):
|
||||
listener = CMAPListener()
|
||||
client = single_client(event_listeners=[listener])
|
||||
client = self.single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
self.assertEqual(listener.event_count(PoolCreatedEvent), 1)
|
||||
|
||||
@ -376,7 +373,7 @@ class TestCMAP(IntegrationTest):
|
||||
|
||||
def test_5_check_out_fails_connection_error(self):
|
||||
listener = CMAPListener()
|
||||
client = single_client(event_listeners=[listener])
|
||||
client = self.single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
pool = get_pool(client)
|
||||
|
||||
@ -403,7 +400,7 @@ class TestCMAP(IntegrationTest):
|
||||
@client_context.require_no_fips
|
||||
def test_5_check_out_fails_auth_error(self):
|
||||
listener = CMAPListener()
|
||||
client = single_client_noauth(
|
||||
client = self.single_client_noauth(
|
||||
username="notauser", password="fail", event_listeners=[listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -449,7 +446,7 @@ class TestCMAP(IntegrationTest):
|
||||
|
||||
def test_close_leaves_pool_unpaused(self):
|
||||
listener = CMAPListener()
|
||||
client = single_client(event_listeners=[listener])
|
||||
client = self.single_client(event_listeners=[listener])
|
||||
client.admin.command("ping")
|
||||
pool = get_pool(client)
|
||||
client.close()
|
||||
|
||||
@ -24,7 +24,6 @@ from test.utils import (
|
||||
CMAPListener,
|
||||
ensure_all_connected,
|
||||
repl_set_step_down,
|
||||
rs_or_single_client,
|
||||
)
|
||||
|
||||
from bson import SON
|
||||
@ -43,7 +42,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.listener = CMAPListener()
|
||||
cls.client = rs_or_single_client(
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
)
|
||||
|
||||
|
||||
@ -34,8 +34,8 @@ from test.utils import (
|
||||
AllowListEventListener,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -43,7 +43,7 @@ from bson import decode_all
|
||||
from bson.code import Code
|
||||
from pymongo import ASCENDING, DESCENDING
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
|
||||
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure, PyMongoError
|
||||
from pymongo.operations import _IndexList
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
@ -230,7 +230,7 @@ class TestCursor(IntegrationTest):
|
||||
self.assertEqual(90, cursor._max_await_time_ms)
|
||||
|
||||
listener = AllowListEventListener("find", "getMore")
|
||||
coll = (rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test
|
||||
coll = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test
|
||||
|
||||
# Tailable_defaults.
|
||||
coll.find(cursor_type=CursorType.TAILABLE_AWAIT).to_list()
|
||||
@ -345,8 +345,7 @@ class TestCursor(IntegrationTest):
|
||||
def test_explain_with_read_concern(self):
|
||||
# Do not add readConcern level to explain.
|
||||
listener = AllowListEventListener("explain")
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local"))
|
||||
self.assertTrue(coll.find().explain())
|
||||
started = listener.started_events
|
||||
@ -1252,8 +1251,7 @@ class TestCursor(IntegrationTest):
|
||||
self.client._process_periodic_tasks()
|
||||
|
||||
listener = AllowListEventListener("killCursors")
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test_close_kills_cursors
|
||||
|
||||
# Add some test data.
|
||||
@ -1291,8 +1289,7 @@ class TestCursor(IntegrationTest):
|
||||
@client_context.require_failCommand_appName
|
||||
def test_timeout_kills_cursor_synchronously(self):
|
||||
listener = AllowListEventListener("killCursors")
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test_timeout_kills_cursor
|
||||
|
||||
# Add some test data.
|
||||
@ -1349,8 +1346,7 @@ class TestCursor(IntegrationTest):
|
||||
|
||||
def test_getMore_does_not_send_readPreference(self):
|
||||
listener = AllowListEventListener("find", "getMore")
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
# We never send primary read preference so override the default.
|
||||
coll = client[self.db.name].get_collection(
|
||||
"test", read_preference=ReadPreference.PRIMARY_PREFERRED
|
||||
@ -1406,6 +1402,18 @@ class TestCursor(IntegrationTest):
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 2)
|
||||
|
||||
def test_to_list_csot_applied(self):
|
||||
client = self.single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
cursor = coll.find({"$where": delay(1)})
|
||||
with self.assertRaises(PyMongoError) as ctx:
|
||||
cursor.to_list()
|
||||
self.assertTrue(ctx.exception.timeout)
|
||||
|
||||
@client_context.require_change_streams
|
||||
def test_command_cursor_to_list(self):
|
||||
# Set maxAwaitTimeMS=1 to speed up the test.
|
||||
@ -1435,6 +1443,25 @@ class TestCursor(IntegrationTest):
|
||||
result = db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(result.to_list(1)), 1)
|
||||
|
||||
@client_context.require_failCommand_blockConnection
|
||||
def test_command_cursor_to_list_csot_applied(self):
|
||||
client = self.single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 5},
|
||||
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 1000},
|
||||
}
|
||||
cursor = coll.aggregate([], batchSize=1)
|
||||
with self.fail_point(fail_command):
|
||||
with self.assertRaises(PyMongoError) as ctx:
|
||||
cursor.to_list()
|
||||
self.assertTrue(ctx.exception.timeout)
|
||||
|
||||
|
||||
class TestRawBatchCursor(IntegrationTest):
|
||||
def test_find_raw(self):
|
||||
@ -1454,7 +1481,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
with client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
batches = (
|
||||
@ -1484,7 +1511,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
|
||||
):
|
||||
@ -1505,7 +1532,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
db = client[self.db.name]
|
||||
with client.start_session(snapshot=True) as session:
|
||||
db.test.distinct("x", {}, session=session)
|
||||
@ -1566,7 +1593,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
|
||||
def test_monitoring(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
c = client.pymongo_test.test
|
||||
c.drop()
|
||||
c.insert_many([{"_id": i} for i in range(10)])
|
||||
@ -1632,7 +1659,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
with client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
batches = (
|
||||
@ -1663,7 +1690,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}}
|
||||
):
|
||||
@ -1687,7 +1714,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
client = self.rs_or_single_client(event_listeners=[listener], retryReads=True)
|
||||
db = client[self.db.name]
|
||||
with client.start_session(snapshot=True) as session:
|
||||
db.test.distinct("x", {}, session=session)
|
||||
@ -1733,7 +1760,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
|
||||
def test_monitoring(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
c = client.pymongo_test.test
|
||||
c.drop()
|
||||
c.insert_many([{"_id": i} for i in range(10)])
|
||||
@ -1777,8 +1804,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
@client_context.require_no_mongos
|
||||
def test_exhaust_cursor_db_set(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
c = client.pymongo_test.test
|
||||
c.delete_many({})
|
||||
c.insert_many([{"_id": i} for i in range(3)])
|
||||
|
||||
@ -27,7 +27,6 @@ sys.path[0:0] = [""]
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.test_client import IntegrationTest
|
||||
from test.utils import rs_client
|
||||
|
||||
from bson import (
|
||||
_BUILT_IN_TYPES,
|
||||
@ -971,7 +970,7 @@ class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustom
|
||||
if codec_options:
|
||||
kwargs["type_registry"] = codec_options.type_registry
|
||||
kwargs["document_class"] = codec_options.document_class
|
||||
self.watched_target = rs_client(*args, **kwargs)
|
||||
self.watched_target = self.rs_client(*args, **kwargs)
|
||||
self.addCleanup(self.watched_target.close)
|
||||
self.input_target = self.watched_target[self.db.name].test
|
||||
# Insert a record to ensure db, coll are created.
|
||||
|
||||
@ -27,8 +27,6 @@ from test import IntegrationTest, client_context, unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
rs_client_noauth,
|
||||
rs_or_single_client,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.data_lake
|
||||
@ -65,7 +63,7 @@ class TestDataLakeProse(IntegrationTest):
|
||||
# Test killCursors
|
||||
def test_1(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2)
|
||||
next(cursor)
|
||||
|
||||
@ -90,13 +88,13 @@ class TestDataLakeProse(IntegrationTest):
|
||||
|
||||
# Test no auth
|
||||
def test_2(self):
|
||||
client = rs_client_noauth()
|
||||
client = self.rs_client_noauth()
|
||||
client.admin.command("ping")
|
||||
|
||||
# Test with auth
|
||||
def test_3(self):
|
||||
for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]:
|
||||
client = rs_or_single_client(authMechanism=mechanism)
|
||||
client = self.rs_or_single_client(authMechanism=mechanism)
|
||||
client[self.TEST_DB][self.TEST_COLLECTION].find_one()
|
||||
|
||||
|
||||
|
||||
@ -28,7 +28,6 @@ from test.test_custom_types import DECIMAL_CODECOPTS
|
||||
from test.utils import (
|
||||
IMPOSSIBLE_WRITE_CONCERN,
|
||||
OvertCommandListener,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -207,7 +206,7 @@ class TestDatabase(IntegrationTest):
|
||||
|
||||
def test_list_collection_names_filter(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
db = client[self.db.name]
|
||||
db.capped.drop()
|
||||
db.create_collection("capped", capped=True, size=4096)
|
||||
@ -234,8 +233,7 @@ class TestDatabase(IntegrationTest):
|
||||
|
||||
def test_check_exists(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
db = client[self.db.name]
|
||||
db.drop_collection("unique")
|
||||
db.create_collection("unique", check_exists=True)
|
||||
@ -323,7 +321,7 @@ class TestDatabase(IntegrationTest):
|
||||
self.client.drop_database("pymongo_test")
|
||||
|
||||
def test_list_collection_names_single_socket(self):
|
||||
client = rs_or_single_client(maxPoolSize=1)
|
||||
client = self.rs_or_single_client(maxPoolSize=1)
|
||||
client.drop_database("test_collection_names_single_socket")
|
||||
db = client.test_collection_names_single_socket
|
||||
for i in range(200):
|
||||
|
||||
@ -22,7 +22,7 @@ import threading
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, unittest
|
||||
from test import IntegrationTest, PyMongoTestCase, unittest
|
||||
from test.pymongo_mocks import DummyMonitor
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
@ -32,9 +32,7 @@ from test.utils import (
|
||||
assertion_context,
|
||||
client_context,
|
||||
get_pool,
|
||||
rs_or_single_client,
|
||||
server_name_to_type,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
@ -272,7 +270,7 @@ class TestIgnoreStaleErrors(IntegrationTest):
|
||||
def test_ignore_stale_connection_errors(self):
|
||||
N_THREADS = 5
|
||||
barrier = threading.Barrier(N_THREADS, timeout=30)
|
||||
client = rs_or_single_client(minPoolSize=N_THREADS)
|
||||
client = self.rs_or_single_client(minPoolSize=N_THREADS)
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# Wait for initial discovery.
|
||||
@ -319,7 +317,7 @@ class TestPoolManagement(IntegrationTest):
|
||||
def test_pool_unpause(self):
|
||||
# This test implements the prose test "Connection Pool Management"
|
||||
listener = CMAPHeartbeatListener()
|
||||
client = single_client(
|
||||
client = self.single_client(
|
||||
appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -353,7 +351,7 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
super().setUp()
|
||||
|
||||
def test_rtt_connection_is_enabled_stream(self):
|
||||
client = rs_or_single_client(serverMonitoringMode="stream")
|
||||
client = self.rs_or_single_client(serverMonitoringMode="stream")
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping")
|
||||
|
||||
@ -373,7 +371,7 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
wait_until(predicate, "find all RTT monitors")
|
||||
|
||||
def test_rtt_connection_is_disabled_poll(self):
|
||||
client = rs_or_single_client(serverMonitoringMode="poll")
|
||||
client = self.rs_or_single_client(serverMonitoringMode="poll")
|
||||
self.addCleanup(client.close)
|
||||
self.assert_rtt_connection_is_disabled(client)
|
||||
|
||||
@ -387,7 +385,7 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
]
|
||||
for env in envs:
|
||||
with patch.dict("os.environ", env):
|
||||
client = rs_or_single_client(serverMonitoringMode="auto")
|
||||
client = self.rs_or_single_client(serverMonitoringMode="auto")
|
||||
self.addCleanup(client.close)
|
||||
self.assert_rtt_connection_is_disabled(client)
|
||||
|
||||
@ -415,7 +413,7 @@ class TCPServer(socketserver.TCPServer):
|
||||
self.server_close()
|
||||
|
||||
|
||||
class TestHeartbeatStartOrdering(unittest.TestCase):
|
||||
class TestHeartbeatStartOrdering(PyMongoTestCase):
|
||||
def test_heartbeat_start_ordering(self):
|
||||
events = []
|
||||
listener = HeartbeatEventsListListener(events)
|
||||
@ -423,7 +421,7 @@ class TestHeartbeatStartOrdering(unittest.TestCase):
|
||||
server.events = events
|
||||
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
|
||||
server_thread.start()
|
||||
_c = MongoClient(
|
||||
_c = self.simple_client(
|
||||
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
|
||||
)
|
||||
server_thread.join()
|
||||
|
||||
@ -22,16 +22,15 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context, unittest
|
||||
from test.utils import wait_until
|
||||
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.uri_parser import parse_uri, split_hosts
|
||||
|
||||
|
||||
class TestDNSRepl(unittest.TestCase):
|
||||
class TestDNSRepl(PyMongoTestCase):
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set"
|
||||
)
|
||||
@ -42,7 +41,7 @@ class TestDNSRepl(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestDNSLoadBalanced(unittest.TestCase):
|
||||
class TestDNSLoadBalanced(PyMongoTestCase):
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced"
|
||||
)
|
||||
@ -53,7 +52,7 @@ class TestDNSLoadBalanced(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestDNSSharded(unittest.TestCase):
|
||||
class TestDNSSharded(PyMongoTestCase):
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded")
|
||||
load_balanced = False
|
||||
|
||||
@ -120,7 +119,7 @@ def create_test(test_case):
|
||||
# tests.
|
||||
copts["tlsAllowInvalidHostnames"] = True
|
||||
|
||||
client = MongoClient(uri, **copts)
|
||||
client = PyMongoTestCase.unmanaged_simple_client(uri, **copts)
|
||||
if num_seeds is not None:
|
||||
self.assertEqual(len(client._topology_settings.seeds), num_seeds)
|
||||
if hosts is not None:
|
||||
@ -133,6 +132,7 @@ def create_test(test_case):
|
||||
client.admin.command("ping")
|
||||
# XXX: we should block until SRV poller runs at least once
|
||||
# and re-run these assertions.
|
||||
client.close()
|
||||
else:
|
||||
try:
|
||||
parse_uri(uri)
|
||||
@ -157,37 +157,37 @@ create_tests(TestDNSLoadBalanced)
|
||||
create_tests(TestDNSSharded)
|
||||
|
||||
|
||||
class TestParsingErrors(unittest.TestCase):
|
||||
class TestParsingErrors(PyMongoTestCase):
|
||||
def test_invalid_host(self):
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb is not",
|
||||
MongoClient,
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb.com is not",
|
||||
MongoClient,
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb.com",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
MongoClient,
|
||||
self.simple_client,
|
||||
"mongodb+srv://127.0.0.1",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
MongoClient,
|
||||
self.simple_client,
|
||||
"mongodb+srv://[::1]",
|
||||
)
|
||||
|
||||
|
||||
class TestCaseInsensitive(IntegrationTest):
|
||||
def test_connect_case_insensitive(self):
|
||||
client = MongoClient("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||
self.addCleanup(client.close)
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ import warnings
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context
|
||||
from test.test_bulk import BulkTestBase
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@ -53,6 +53,7 @@ from test.helpers import (
|
||||
KMIP_CREDS,
|
||||
LOCAL_MASTER_KEY,
|
||||
)
|
||||
from test.test_bulk import BulkTestBase
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
@ -61,7 +62,6 @@ from test.utils import (
|
||||
TopologyEventListener,
|
||||
camel_to_snake_args,
|
||||
is_greenthread_patched,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunner
|
||||
@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
@unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed")
|
||||
def test_crypt_shared(self):
|
||||
# Test that we can pick up crypt_shared lib automatically
|
||||
client = MongoClient(
|
||||
self.simple_client(
|
||||
auto_encryption_opts=AutoEncryptionOpts(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True
|
||||
),
|
||||
connect=False,
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
|
||||
@unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed")
|
||||
def test_init_requires_pymongocrypt(self):
|
||||
@ -196,19 +195,16 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
|
||||
class TestClientOptions(PyMongoTestCase):
|
||||
def test_default(self):
|
||||
client = MongoClient(connect=False)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(connect=False)
|
||||
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
|
||||
|
||||
client = MongoClient(auto_encryption_opts=None, connect=False)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(auto_encryption_opts=None, connect=False)
|
||||
self.assertEqual(get_client_opts(client).auto_encryption_opts, None)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def test_kwargs(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = MongoClient(auto_encryption_opts=opts, connect=False)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(auto_encryption_opts=opts, connect=False)
|
||||
self.assertEqual(get_client_opts(client).auto_encryption_opts, opts)
|
||||
|
||||
|
||||
@ -229,6 +225,34 @@ class EncryptionIntegrationTest(IntegrationTest):
|
||||
self.assertIsInstance(val, Binary)
|
||||
self.assertEqual(val.subtype, UUID_SUBTYPE)
|
||||
|
||||
def create_client_encryption(
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: MongoClient,
|
||||
codec_options: CodecOptions,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
):
|
||||
client_encryption = ClientEncryption(
|
||||
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
return client_encryption
|
||||
|
||||
@classmethod
|
||||
def unmanaged_create_client_encryption(
|
||||
cls,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: MongoClient,
|
||||
codec_options: CodecOptions,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
):
|
||||
client_encryption = ClientEncryption(
|
||||
kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options
|
||||
)
|
||||
return client_encryption
|
||||
|
||||
|
||||
# Location of JSON test files.
|
||||
if _IS_SYNC:
|
||||
@ -260,8 +284,7 @@ def bson_data(*paths):
|
||||
|
||||
class TestClientSimple(EncryptionIntegrationTest):
|
||||
def _test_auto_encrypt(self, opts):
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
# Create the encrypted field's data key.
|
||||
key_vault = create_key_vault(
|
||||
@ -342,8 +365,7 @@ class TestClientSimple(EncryptionIntegrationTest):
|
||||
|
||||
def test_use_after_close(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
client.admin.command("ping")
|
||||
client.close()
|
||||
@ -358,10 +380,10 @@ class TestClientSimple(EncryptionIntegrationTest):
|
||||
is_greenthread_patched(),
|
||||
"gevent and eventlet do not support POSIX-style forking.",
|
||||
)
|
||||
@client_context.require_sync
|
||||
def test_fork(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
def target():
|
||||
with warnings.catch_warnings():
|
||||
@ -375,8 +397,7 @@ class TestClientSimple(EncryptionIntegrationTest):
|
||||
class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
|
||||
def test_upsert_uuid_standard_encrypt(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
encrypted_coll = client.pymongo_test.test
|
||||
@ -416,8 +437,7 @@ class TestClientMaxWireVersion(IntegrationTest):
|
||||
@client_context.require_version_max(4, 0, 99)
|
||||
def test_raise_max_wire_version_error(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
msg = "Auto-encryption requires a minimum MongoDB version of 4.2"
|
||||
with self.assertRaisesRegex(ConfigurationError, msg):
|
||||
client.test.test.insert_one({})
|
||||
@ -430,8 +450,7 @@ class TestClientMaxWireVersion(IntegrationTest):
|
||||
|
||||
def test_raise_unsupported_error(self):
|
||||
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
|
||||
client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
msg = "find_raw_batches does not support auto encryption"
|
||||
with self.assertRaisesRegex(InvalidOperation, msg):
|
||||
client.test.test.find_raw_batches({})
|
||||
@ -450,10 +469,9 @@ class TestClientMaxWireVersion(IntegrationTest):
|
||||
|
||||
class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
def test_encrypt_decrypt(self):
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
# Use standard UUID representation.
|
||||
key_vault = client_context.client.keyvault.get_collection("datakeys", codec_options=OPTS)
|
||||
self.addCleanup(key_vault.drop)
|
||||
@ -493,10 +511,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
self.assertEqual(decrypted_ssn, doc["ssn"])
|
||||
|
||||
def test_validation(self):
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
|
||||
msg = "value to decrypt must be a bson.binary.Binary with subtype 6"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
@ -510,10 +527,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
client_encryption.encrypt("str", algo, key_id=Binary(b"123"))
|
||||
|
||||
def test_bson_errors(self):
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
|
||||
# Attempt to encrypt an unencodable object.
|
||||
unencodable_value = object()
|
||||
@ -526,7 +542,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
|
||||
def test_codec_options(self):
|
||||
with self.assertRaisesRegex(TypeError, "codec_options must be"):
|
||||
ClientEncryption(
|
||||
self.create_client_encryption(
|
||||
KMS_PROVIDERS,
|
||||
"keyvault.datakeys",
|
||||
client_context.client,
|
||||
@ -534,10 +550,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
)
|
||||
|
||||
opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY)
|
||||
client_encryption_legacy = ClientEncryption(
|
||||
client_encryption_legacy = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts
|
||||
)
|
||||
self.addCleanup(client_encryption_legacy.close)
|
||||
|
||||
# Create the encrypted field's data key.
|
||||
key_id = client_encryption_legacy.create_data_key("local")
|
||||
@ -552,10 +567,9 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
|
||||
# Encrypt the same UUID with STANDARD codec options.
|
||||
opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
encrypted_standard = client_encryption.encrypt(
|
||||
value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
|
||||
)
|
||||
@ -571,7 +585,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
self.assertNotEqual(client_encryption.decrypt(encrypted_legacy), value)
|
||||
|
||||
def test_close(self):
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
|
||||
)
|
||||
client_encryption.close()
|
||||
@ -587,7 +601,7 @@ class TestExplicitSimple(EncryptionIntegrationTest):
|
||||
client_encryption.decrypt(Binary(b"", 6))
|
||||
|
||||
def test_with_statement(self):
|
||||
with ClientEncryption(
|
||||
with self.create_client_encryption(
|
||||
KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
|
||||
) as client_encryption:
|
||||
pass
|
||||
@ -807,7 +821,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client.db.coll.drop()
|
||||
cls.vault = create_key_vault(cls.client.keyvault.datakeys)
|
||||
|
||||
@ -829,10 +843,10 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
opts = AutoEncryptionOpts(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
cls.client_encrypted = rs_or_single_client(
|
||||
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
cls.client_encryption = ClientEncryption(
|
||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
|
||||
@ -919,8 +933,7 @@ class TestExternalKeyVault(EncryptionIntegrationTest):
|
||||
# Configure the encrypted field via the local schema_map option.
|
||||
schemas = {"db.coll": json_data("external", "external-schema.json")}
|
||||
if with_external_key_vault:
|
||||
key_vault_client = rs_or_single_client(username="fake-user", password="fake-pwd")
|
||||
self.addCleanup(key_vault_client.close)
|
||||
key_vault_client = self.rs_or_single_client(username="fake-user", password="fake-pwd")
|
||||
else:
|
||||
key_vault_client = client_context.client
|
||||
opts = AutoEncryptionOpts(
|
||||
@ -930,15 +943,13 @@ class TestExternalKeyVault(EncryptionIntegrationTest):
|
||||
key_vault_client=key_vault_client,
|
||||
)
|
||||
|
||||
client_encrypted = rs_or_single_client(
|
||||
client_encrypted = self.rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
|
||||
if with_external_key_vault:
|
||||
# Authentication error.
|
||||
@ -984,10 +995,9 @@ class TestViews(EncryptionIntegrationTest):
|
||||
self.addCleanup(self.client.db.view.drop)
|
||||
|
||||
opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys")
|
||||
client_encrypted = rs_or_single_client(
|
||||
client_encrypted = self.rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
|
||||
with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"):
|
||||
client_encrypted.db.view.insert_one({})
|
||||
@ -1044,17 +1054,15 @@ class TestCorpus(EncryptionIntegrationTest):
|
||||
)
|
||||
self.addCleanup(vault.drop)
|
||||
|
||||
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.kms_providers(),
|
||||
"keyvault.datakeys",
|
||||
client_context.client,
|
||||
OPTS,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
|
||||
corpus = self.fix_up_curpus(json_data("corpus", "corpus.json"))
|
||||
corpus_copied: SON = SON()
|
||||
@ -1197,7 +1205,7 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
|
||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client_encrypted = rs_or_single_client(
|
||||
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
||||
)
|
||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
||||
@ -1285,7 +1293,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||
"gcp": GCP_CREDS,
|
||||
"kmip": KMIP_CREDS,
|
||||
}
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers=kms_providers,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=client_context.client,
|
||||
@ -1297,7 +1305,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||
kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443"
|
||||
kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443"
|
||||
kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698"
|
||||
self.client_encryption_invalid = ClientEncryption(
|
||||
self.client_encryption_invalid = self.create_client_encryption(
|
||||
kms_providers=kms_providers_invalid,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=client_context.client,
|
||||
@ -1476,7 +1484,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||
self.client_encryption.create_data_key("kmip", key)
|
||||
|
||||
|
||||
class AzureGCPEncryptionTestMixin:
|
||||
class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
||||
DEK = None
|
||||
KMS_PROVIDER_MAP = None
|
||||
KEYVAULT_DB = "keyvault"
|
||||
@ -1488,7 +1496,7 @@ class AzureGCPEncryptionTestMixin:
|
||||
create_key_vault(keyvault, self.DEK)
|
||||
|
||||
def _test_explicit(self, expectation):
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||
client_context.client,
|
||||
@ -1517,7 +1525,7 @@ class AzureGCPEncryptionTestMixin:
|
||||
)
|
||||
|
||||
insert_listener = AllowListEventListener("insert")
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -1596,19 +1604,17 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests
|
||||
class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
def setUp(self):
|
||||
self.client_test = rs_or_single_client(
|
||||
self.client_test = self.rs_or_single_client(
|
||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||
)
|
||||
self.addCleanup(self.client_test.close)
|
||||
|
||||
self.client_keyvault_listener = OvertCommandListener()
|
||||
self.client_keyvault = rs_or_single_client(
|
||||
self.client_keyvault = self.rs_or_single_client(
|
||||
maxPoolSize=1,
|
||||
readConcernLevel="majority",
|
||||
w="majority",
|
||||
event_listeners=[self.client_keyvault_listener],
|
||||
)
|
||||
self.addCleanup(self.client_keyvault.close)
|
||||
|
||||
self.client_test.keyvault.datakeys.drop()
|
||||
self.client_test.db.coll.drop()
|
||||
@ -1619,7 +1625,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
codec_options=OPTS,
|
||||
)
|
||||
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=self.client_test,
|
||||
@ -1635,7 +1641,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
|
||||
def _run_test(self, max_pool_size, auto_encryption_opts):
|
||||
client_encrypted = rs_or_single_client(
|
||||
client_encrypted = self.rs_or_single_client(
|
||||
readConcernLevel="majority",
|
||||
w="majority",
|
||||
maxPoolSize=max_pool_size,
|
||||
@ -1653,8 +1659,6 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
result = client_encrypted.db.coll.find_one({"_id": 0})
|
||||
self.assertEqual(result, {"_id": 0, "encrypted": "string0"})
|
||||
|
||||
self.addCleanup(client_encrypted.close)
|
||||
|
||||
def test_case_1(self):
|
||||
self._run_test(
|
||||
max_pool_size=1,
|
||||
@ -1830,7 +1834,7 @@ class TestDecryptProse(EncryptionIntegrationTest):
|
||||
create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
|
||||
)
|
||||
keyID = self.client_encryption.create_data_key("local")
|
||||
@ -1845,10 +1849,9 @@ class TestDecryptProse(EncryptionIntegrationTest):
|
||||
key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map
|
||||
)
|
||||
self.listener = AllowListEventListener("aggregate")
|
||||
self.encrypted_client = rs_or_single_client(
|
||||
self.encrypted_client = self.rs_or_single_client(
|
||||
auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener]
|
||||
)
|
||||
self.addCleanup(self.encrypted_client.close)
|
||||
|
||||
def test_01_command_error(self):
|
||||
with self.fail_point(
|
||||
@ -1925,8 +1928,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
"--port=27027",
|
||||
],
|
||||
)
|
||||
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
with self.assertRaisesRegex(EncryptionError, "Timeout"):
|
||||
client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
|
||||
@ -1940,11 +1942,12 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
"--port=27027",
|
||||
],
|
||||
)
|
||||
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
client_encrypted.db.coll.insert_one({"unencrypted": "test"})
|
||||
# Validate that mongocryptd was not spawned:
|
||||
mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500")
|
||||
mongocryptd_client = self.simple_client(
|
||||
"mongodb://localhost:27027/?serverSelectionTimeoutMS=500"
|
||||
)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
mongocryptd_client.admin.command("ping")
|
||||
|
||||
@ -1966,15 +1969,13 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
],
|
||||
crypt_shared_lib_required=True,
|
||||
)
|
||||
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
client_encrypted.db.coll.drop()
|
||||
client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
self.assertEncrypted((client_context.client.db.coll.find_one({}))["encrypted"])
|
||||
no_mongocryptd_client = MongoClient(
|
||||
no_mongocryptd_client = self.simple_client(
|
||||
host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000"
|
||||
)
|
||||
self.addCleanup(no_mongocryptd_client.close)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
no_mongocryptd_client.db.command("ping")
|
||||
|
||||
@ -2008,8 +2009,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
mongocryptd_uri="mongodb://localhost:47021",
|
||||
crypt_shared_lib_required=False,
|
||||
)
|
||||
client_encrypted = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(client_encrypted.close)
|
||||
client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
client_encrypted.db.coll.drop()
|
||||
client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
server.shutdown()
|
||||
@ -2023,10 +2023,9 @@ class TestKmsTLSProse(EncryptionIntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.patch_system_certs(CA_PEM)
|
||||
self.client_encrypted = ClientEncryption(
|
||||
self.client_encrypted = self.create_client_encryption(
|
||||
{"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS
|
||||
)
|
||||
self.addCleanup(self.client_encrypted.close)
|
||||
|
||||
def test_invalid_kms_certificate_expired(self):
|
||||
key = {
|
||||
@ -2071,36 +2070,32 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
"gcp": {"tlsCAFile": CA_PEM},
|
||||
"kmip": {"tlsCAFile": CA_PEM},
|
||||
}
|
||||
self.client_encryption_no_client_cert = ClientEncryption(
|
||||
self.client_encryption_no_client_cert = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
|
||||
)
|
||||
self.addCleanup(self.client_encryption_no_client_cert.close)
|
||||
# 2, same providers as above but with tlsCertificateKeyFile.
|
||||
kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only)
|
||||
for p in kms_tls_opts:
|
||||
kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM
|
||||
self.client_encryption_with_tls = ClientEncryption(
|
||||
self.client_encryption_with_tls = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
|
||||
)
|
||||
self.addCleanup(self.client_encryption_with_tls.close)
|
||||
# 3, update endpoints to expired host.
|
||||
providers: dict = copy.deepcopy(providers)
|
||||
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000"
|
||||
providers["gcp"]["endpoint"] = "127.0.0.1:9000"
|
||||
providers["kmip"]["endpoint"] = "127.0.0.1:9000"
|
||||
self.client_encryption_expired = ClientEncryption(
|
||||
self.client_encryption_expired = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
|
||||
)
|
||||
self.addCleanup(self.client_encryption_expired.close)
|
||||
# 3, update endpoints to invalid host.
|
||||
providers: dict = copy.deepcopy(providers)
|
||||
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001"
|
||||
providers["gcp"]["endpoint"] = "127.0.0.1:9001"
|
||||
providers["kmip"]["endpoint"] = "127.0.0.1:9001"
|
||||
self.client_encryption_invalid_hostname = ClientEncryption(
|
||||
self.client_encryption_invalid_hostname = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
|
||||
)
|
||||
self.addCleanup(self.client_encryption_invalid_hostname.close)
|
||||
# Errors when client has no cert, some examples:
|
||||
# [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623)
|
||||
self.cert_error = (
|
||||
@ -2138,7 +2133,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
"gcp:with_tls": with_cert,
|
||||
"kmip:with_tls": with_cert,
|
||||
}
|
||||
self.client_encryption_with_names = ClientEncryption(
|
||||
self.client_encryption_with_names = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4
|
||||
)
|
||||
|
||||
@ -2220,10 +2215,9 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self):
|
||||
providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}
|
||||
options = {"aws": {"tlsDisableOCSPEndpointCheck": True}}
|
||||
encryption = ClientEncryption(
|
||||
encryption = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
|
||||
)
|
||||
self.addCleanup(encryption.close)
|
||||
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
|
||||
if not hasattr(ctx, "check_ocsp_endpoint"):
|
||||
raise self.skipTest("OCSP not enabled")
|
||||
@ -2273,7 +2267,7 @@ class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
|
||||
self.client = client_context.client
|
||||
create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers_map, "keyvault.datakeys", self.client, CodecOptions()
|
||||
)
|
||||
self.def_key_id = self.client_encryption.create_data_key("local", key_alt_names=["def"])
|
||||
@ -2311,17 +2305,15 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest):
|
||||
key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
|
||||
self.addCleanup(key_vault.drop)
|
||||
self.key_vault_client = self.client
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
|
||||
)
|
||||
self.addCleanup(self.client_encryption.close)
|
||||
opts = AutoEncryptionOpts(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}},
|
||||
key_vault.full_name,
|
||||
bypass_query_analysis=True,
|
||||
)
|
||||
self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addCleanup(self.encrypted_client.close)
|
||||
self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
|
||||
def test_01_insert_encrypted_indexed_and_find(self):
|
||||
val = "encrypted indexed value"
|
||||
@ -2444,14 +2436,13 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest):
|
||||
self.client.keyvault.drop_collection("datakeys")
|
||||
|
||||
# Step 2. Create a ``ClientEncryption`` object named ``client_encryption1``
|
||||
client_encryption1 = ClientEncryption(
|
||||
client_encryption1 = self.create_client_encryption(
|
||||
key_vault_client=self.client,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
kms_providers=ALL_KMS_PROVIDERS,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
codec_options=OPTS,
|
||||
)
|
||||
self.addCleanup(client_encryption1.close)
|
||||
|
||||
# Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``.
|
||||
key_id = client_encryption1.create_data_key(
|
||||
@ -2464,16 +2455,14 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest):
|
||||
)
|
||||
|
||||
# Step 5. Create a ``ClientEncryption`` object named ``client_encryption2``
|
||||
client2 = rs_or_single_client()
|
||||
self.addCleanup(client2.close)
|
||||
client_encryption2 = ClientEncryption(
|
||||
client2 = self.rs_or_single_client()
|
||||
client_encryption2 = self.create_client_encryption(
|
||||
key_vault_client=client2,
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
kms_providers=ALL_KMS_PROVIDERS,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
codec_options=OPTS,
|
||||
)
|
||||
self.addCleanup(client_encryption2.close)
|
||||
|
||||
# Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``.
|
||||
rewrap_many_data_key_result = client_encryption2.rewrap_many_data_key(
|
||||
@ -2508,7 +2497,7 @@ class TestOnDemandAWSCredentials(EncryptionIntegrationTest):
|
||||
|
||||
@unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set")
|
||||
def test_01_failure(self):
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers={"aws": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=client_context.client,
|
||||
@ -2519,7 +2508,7 @@ class TestOnDemandAWSCredentials(EncryptionIntegrationTest):
|
||||
|
||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||
def test_02_success(self):
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
kms_providers={"aws": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=client_context.client,
|
||||
@ -2539,8 +2528,7 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
|
||||
# MongoClient to use in testing that handles auth/tls/etc,
|
||||
# and cleanup.
|
||||
def MongoClient(**kwargs):
|
||||
c = rs_or_single_client(**kwargs)
|
||||
self.addCleanup(c.close)
|
||||
c = self.rs_or_single_client(**kwargs)
|
||||
return c
|
||||
|
||||
# Drop data from prior test runs.
|
||||
@ -2551,7 +2539,7 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
|
||||
|
||||
# Create two data keys.
|
||||
key_vault_client = MongoClient()
|
||||
client_encryption = ClientEncryption(
|
||||
client_encryption = self.create_client_encryption(
|
||||
kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions()
|
||||
)
|
||||
key1_id = client_encryption.create_data_key("local")
|
||||
@ -2632,18 +2620,16 @@ class TestRangeQueryProse(EncryptionIntegrationTest):
|
||||
key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
|
||||
self.addCleanup(key_vault.drop)
|
||||
self.key_vault_client = self.client
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS
|
||||
)
|
||||
self.addCleanup(self.client_encryption.close)
|
||||
opts = AutoEncryptionOpts(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}},
|
||||
key_vault.full_name,
|
||||
bypass_query_analysis=True,
|
||||
)
|
||||
self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.db = self.encrypted_client.db
|
||||
self.addCleanup(self.encrypted_client.close)
|
||||
|
||||
def run_expression_find(
|
||||
self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None
|
||||
@ -2838,10 +2824,9 @@ class TestRangeQueryDefaultsProse(EncryptionIntegrationTest):
|
||||
super().setUp()
|
||||
self.client.drop_database(self.db)
|
||||
self.key_vault_client = self.client
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS
|
||||
)
|
||||
self.addCleanup(self.client_encryption.close)
|
||||
self.key_id = self.client_encryption.create_data_key("local")
|
||||
opts = RangeOpts(min=0, max=1000)
|
||||
self.payload_defaults = self.client_encryption.encrypt(
|
||||
@ -2874,13 +2859,12 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest):
|
||||
self.client.drop_database(self.db)
|
||||
self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
|
||||
self.addCleanup(self.key_vault.drop)
|
||||
self.client_encryption = ClientEncryption(
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
{"local": {"key": LOCAL_MASTER_KEY}},
|
||||
self.key_vault.full_name,
|
||||
self.client,
|
||||
OPTS,
|
||||
)
|
||||
self.addCleanup(self.client_encryption.close)
|
||||
|
||||
def test_01_simple_create(self):
|
||||
coll, _ = self.client_encryption.create_encrypted_collection(
|
||||
@ -3096,10 +3080,9 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.listener = OvertCommandListener()
|
||||
self.mongocryptd_client = MongoClient(
|
||||
self.mongocryptd_client = self.simple_client(
|
||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||
)
|
||||
self.addCleanup(self.mongocryptd_client.close)
|
||||
|
||||
hello = self.mongocryptd_client.db.command("hello")
|
||||
self.assertNotIn("logicalSessionTimeoutMinutes", hello)
|
||||
|
||||
@ -22,7 +22,7 @@ import threading
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import rs_client, wait_until
|
||||
from test.utils import wait_until
|
||||
|
||||
import pymongo
|
||||
from pymongo.errors import ConnectionFailure, OperationFailure
|
||||
@ -1128,7 +1128,7 @@ class TestTransactionExamples(IntegrationTest):
|
||||
self.assertEqual(employee["status"], "Inactive")
|
||||
|
||||
def MongoClient(_):
|
||||
return rs_client()
|
||||
return self.rs_client()
|
||||
|
||||
uriString = None
|
||||
|
||||
@ -1220,7 +1220,7 @@ class TestVersionedApiExamples(IntegrationTest):
|
||||
def test_versioned_api(self):
|
||||
# Versioned API examples
|
||||
def MongoClient(_, server_api):
|
||||
return rs_client(server_api=server_api, connect=False)
|
||||
return self.rs_client(server_api=server_api, connect=False)
|
||||
|
||||
uri = None
|
||||
|
||||
@ -1251,7 +1251,7 @@ class TestVersionedApiExamples(IntegrationTest):
|
||||
):
|
||||
self.skipTest("This test needs MongoDB 5.0.2 or newer")
|
||||
|
||||
client = rs_client(server_api=ServerApi("1", strict=True))
|
||||
client = self.rs_client(server_api=ServerApi("1", strict=True))
|
||||
client.db.sales.drop()
|
||||
|
||||
# Start Versioned API Example 5
|
||||
|
||||
@ -33,7 +33,7 @@ from pymongo.synchronous.database import Database
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.utils import EventListener, rs_or_single_client
|
||||
from test.utils import EventListener
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs.errors import NoFile
|
||||
@ -790,7 +790,7 @@ Bye"""
|
||||
outfile.readchunk()
|
||||
|
||||
def test_grid_in_lazy_connect(self):
|
||||
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
fs = client.db.fs
|
||||
infile = GridIn(fs, file_id=-1, chunk_size=1)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
@ -801,7 +801,7 @@ Bye"""
|
||||
def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
GridIn((rs_or_single_client(w=0)).pymongo_test.fs)
|
||||
GridIn((self.rs_or_single_client(w=0)).pymongo_test.fs)
|
||||
|
||||
def test_survive_cursor_not_found(self):
|
||||
# By default the find command returns 101 documents in the first batch.
|
||||
@ -809,7 +809,7 @@ Bye"""
|
||||
chunk_size = 1024
|
||||
data = b"d" * (102 * chunk_size)
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
db = client.pymongo_test
|
||||
with GridIn(db.fs, chunk_size=chunk_size) as infile:
|
||||
infile.write(data)
|
||||
|
||||
@ -26,7 +26,7 @@ from unittest.mock import patch
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import joinall, one, rs_client, rs_or_single_client, single_client
|
||||
from test.utils import joinall, one
|
||||
|
||||
import gridfs
|
||||
from bson.binary import Binary
|
||||
@ -411,7 +411,7 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertTrue(iterate_file(f))
|
||||
|
||||
def test_gridfs_lazy_connect(self):
|
||||
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
db = client.db
|
||||
gfs = gridfs.GridFS(db)
|
||||
self.assertRaises(ServerSelectionTimeoutError, gfs.list)
|
||||
@ -492,7 +492,7 @@ class TestGridfs(IntegrationTest):
|
||||
def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
gridfs.GridFS(rs_or_single_client(w=0).pymongo_test)
|
||||
gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test)
|
||||
|
||||
def test_md5(self):
|
||||
gin = self.fs.new_file()
|
||||
@ -519,7 +519,7 @@ class TestGridfsReplicaSet(IntegrationTest):
|
||||
client_context.client.drop_database("gfsreplica")
|
||||
|
||||
def test_gridfs_replica_set(self):
|
||||
rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
|
||||
rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
|
||||
|
||||
fs = gridfs.GridFS(rsc.gfsreplica, "gfsreplicatest")
|
||||
|
||||
@ -532,7 +532,7 @@ class TestGridfsReplicaSet(IntegrationTest):
|
||||
|
||||
def test_gridfs_secondary(self):
|
||||
secondary_host, secondary_port = one(self.client.secondaries)
|
||||
secondary_connection = single_client(
|
||||
secondary_connection = self.single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
|
||||
)
|
||||
|
||||
@ -547,7 +547,7 @@ class TestGridfsReplicaSet(IntegrationTest):
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
# create index.
|
||||
secondary_host, secondary_port = one(self.client.secondaries)
|
||||
client = single_client(
|
||||
client = self.single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
|
||||
)
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ from unittest.mock import patch
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import joinall, one, rs_client, rs_or_single_client, single_client
|
||||
from test.utils import joinall, one
|
||||
|
||||
import gridfs
|
||||
from bson.binary import Binary
|
||||
@ -345,7 +345,7 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertTrue(iterate_file(fstr))
|
||||
|
||||
def test_gridfs_lazy_connect(self):
|
||||
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=0)
|
||||
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0)
|
||||
cdb = client.db
|
||||
gfs = gridfs.GridFSBucket(cdb)
|
||||
self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0)
|
||||
@ -391,7 +391,7 @@ class TestGridfs(IntegrationTest):
|
||||
def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test)
|
||||
gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test)
|
||||
|
||||
def test_rename(self):
|
||||
_id = self.fs.upload_from_stream("first_name", b"testing")
|
||||
@ -489,7 +489,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
client_context.client.drop_database("gfsbucketreplica")
|
||||
|
||||
def test_gridfs_replica_set(self):
|
||||
rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
|
||||
rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY)
|
||||
|
||||
gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest")
|
||||
oid = gfs.upload_from_stream("test_filename", b"foo")
|
||||
@ -498,7 +498,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
|
||||
def test_gridfs_secondary(self):
|
||||
secondary_host, secondary_port = one(self.client.secondaries)
|
||||
secondary_connection = single_client(
|
||||
secondary_connection = self.single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
|
||||
)
|
||||
|
||||
@ -513,7 +513,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
# create index.
|
||||
secondary_host, secondary_port = one(self.client.secondaries)
|
||||
client = single_client(
|
||||
client = self.single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_knobs, unittest
|
||||
from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until
|
||||
from test.utils import HeartbeatEventListener, MockPool, wait_until
|
||||
|
||||
from pymongo.errors import ConnectionFailure
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
@ -40,7 +40,7 @@ class TestHeartbeatMonitoring(IntegrationTest):
|
||||
raise responses[1]
|
||||
return Hello(responses[1]), 99
|
||||
|
||||
m = single_client(
|
||||
m = self.single_client(
|
||||
h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool
|
||||
)
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import ExceptionCatchingThread, get_pool, rs_client, wait_until
|
||||
from test.utils import ExceptionCatchingThread, get_pool, wait_until
|
||||
|
||||
pytestmark = pytest.mark.load_balancer
|
||||
|
||||
@ -54,7 +54,7 @@ class TestLB(IntegrationTest):
|
||||
|
||||
@client_context.require_load_balancer
|
||||
def test_unpin_committed_transaction(self):
|
||||
client = rs_client()
|
||||
client = self.rs_client()
|
||||
self.addCleanup(client.close)
|
||||
pool = get_pool(client)
|
||||
coll = client[self.db.name].test
|
||||
@ -85,7 +85,7 @@ class TestLB(IntegrationTest):
|
||||
self._test_no_gc_deadlock(create_resource)
|
||||
|
||||
def _test_no_gc_deadlock(self, create_resource):
|
||||
client = rs_client()
|
||||
client = self.rs_client()
|
||||
self.addCleanup(client.close)
|
||||
pool = get_pool(client)
|
||||
coll = client[self.db.name].test
|
||||
@ -124,7 +124,7 @@ class TestLB(IntegrationTest):
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_session_gc(self):
|
||||
client = rs_client()
|
||||
client = self.rs_client()
|
||||
self.addCleanup(client.close)
|
||||
pool = get_pool(client)
|
||||
session = client.start_session()
|
||||
|
||||
@ -15,7 +15,6 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from test import IntegrationTest, unittest
|
||||
from test.utils import single_client
|
||||
from unittest.mock import patch
|
||||
|
||||
from bson import json_util
|
||||
@ -85,7 +84,7 @@ class TestLogger(IntegrationTest):
|
||||
self.assertEqual(last_3_bytes, str_to_repeat)
|
||||
|
||||
def test_logging_without_listeners(self):
|
||||
c = single_client()
|
||||
c = self.single_client()
|
||||
self.assertEqual(len(c._event_listeners.event_listeners()), 0)
|
||||
with self.assertLogs("pymongo.connection", level="DEBUG") as cm:
|
||||
c.db.test.insert_one({"x": "1"})
|
||||
|
||||
@ -20,15 +20,14 @@ import sys
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.utils import rs_or_single_client
|
||||
from test import PyMongoTestCase, client_context, unittest
|
||||
from test.utils_selection_tests import create_selection_tests
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
@ -40,54 +39,58 @@ class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class TestMaxStaleness(unittest.TestCase):
|
||||
class TestMaxStaleness(PyMongoTestCase):
|
||||
def test_max_staleness(self):
|
||||
client = MongoClient()
|
||||
client = self.simple_client()
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = MongoClient("mongodb://a/?readPreference=secondary")
|
||||
client = self.simple_client("mongodb://a/?readPreference=secondary")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
# These tests are specified in max-staleness-tests.rst.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
# Default read pref "primary" can't be used with max staleness.
|
||||
MongoClient("mongodb://a/?maxStalenessSeconds=120")
|
||||
self.simple_client("mongodb://a/?maxStalenessSeconds=120")
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
# Read pref "primary" can't be used with max staleness.
|
||||
MongoClient("mongodb://a/?readPreference=primary&maxStalenessSeconds=120")
|
||||
self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120")
|
||||
|
||||
client = MongoClient("mongodb://host/?maxStalenessSeconds=-1")
|
||||
client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = MongoClient("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1")
|
||||
client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = MongoClient("mongodb://host/?readPreference=secondary&maxStalenessSeconds=120")
|
||||
client = self.simple_client(
|
||||
"mongodb://host/?readPreference=secondary&maxStalenessSeconds=120"
|
||||
)
|
||||
self.assertEqual(120, client.read_preference.max_staleness)
|
||||
|
||||
client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1")
|
||||
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1")
|
||||
self.assertEqual(1, client.read_preference.max_staleness)
|
||||
|
||||
client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1")
|
||||
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = MongoClient(maxStalenessSeconds=-1, readPreference="nearest")
|
||||
client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
# Prohibit None.
|
||||
MongoClient(maxStalenessSeconds=None, readPreference="nearest")
|
||||
self.simple_client(maxStalenessSeconds=None, readPreference="nearest")
|
||||
|
||||
def test_max_staleness_float(self):
|
||||
with self.assertRaises(TypeError) as ctx:
|
||||
rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
|
||||
self.rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
|
||||
|
||||
self.assertIn("must be an integer", str(ctx.exception))
|
||||
|
||||
with warnings.catch_warnings(record=True) as ctx:
|
||||
warnings.simplefilter("always")
|
||||
client = MongoClient("mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest")
|
||||
client = self.simple_client(
|
||||
"mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest"
|
||||
)
|
||||
|
||||
# Option was ignored.
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
@ -96,13 +99,15 @@ class TestMaxStaleness(unittest.TestCase):
|
||||
def test_max_staleness_zero(self):
|
||||
# Zero is too small.
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
|
||||
self.rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
|
||||
|
||||
self.assertIn("must be a positive integer", str(ctx.exception))
|
||||
|
||||
with warnings.catch_warnings(record=True) as ctx:
|
||||
warnings.simplefilter("always")
|
||||
client = MongoClient("mongodb://host/?maxStalenessSeconds=0&readPreference=nearest")
|
||||
client = self.simple_client(
|
||||
"mongodb://host/?maxStalenessSeconds=0&readPreference=nearest"
|
||||
)
|
||||
|
||||
# Option was ignored.
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
@ -111,7 +116,7 @@ class TestMaxStaleness(unittest.TestCase):
|
||||
@client_context.require_replica_set
|
||||
def test_last_write_date(self):
|
||||
# From max-staleness-tests.rst, "Parse lastWriteDate".
|
||||
client = rs_or_single_client(heartbeatFrequencyMS=500)
|
||||
client = self.rs_or_single_client(heartbeatFrequencyMS=500)
|
||||
client.pymongo_test.test.insert_one({})
|
||||
# Wait for the server description to be updated.
|
||||
time.sleep(1)
|
||||
|
||||
@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import gc
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -25,7 +26,6 @@ sys.path[0:0] = [""]
|
||||
from test import IntegrationTest, connected, unittest
|
||||
from test.utils import (
|
||||
ServerAndTopologyEventListener,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -47,30 +47,31 @@ def get_executors(client):
|
||||
return [e for e in executors if e is not None]
|
||||
|
||||
|
||||
def create_client():
|
||||
listener = ServerAndTopologyEventListener()
|
||||
client = single_client(event_listeners=[listener])
|
||||
connected(client)
|
||||
return client
|
||||
|
||||
|
||||
class TestMonitor(IntegrationTest):
|
||||
def create_client(self):
|
||||
listener = ServerAndTopologyEventListener()
|
||||
client = self.unmanaged_single_client(event_listeners=[listener])
|
||||
connected(client)
|
||||
return client
|
||||
|
||||
def test_cleanup_executors_on_client_del(self):
|
||||
client = create_client()
|
||||
executors = get_executors(client)
|
||||
self.assertEqual(len(executors), 4)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
client = self.create_client()
|
||||
executors = get_executors(client)
|
||||
self.assertEqual(len(executors), 4)
|
||||
|
||||
# Each executor stores a weakref to itself in _EXECUTORS.
|
||||
executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors]
|
||||
# Each executor stores a weakref to itself in _EXECUTORS.
|
||||
executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors]
|
||||
|
||||
del executors
|
||||
del client
|
||||
del executors
|
||||
del client
|
||||
|
||||
for ref, name in executor_refs:
|
||||
wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5)
|
||||
for ref, name in executor_refs:
|
||||
wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5)
|
||||
|
||||
def test_cleanup_executors_on_client_close(self):
|
||||
client = create_client()
|
||||
client = self.create_client()
|
||||
executors = get_executors(client)
|
||||
self.assertEqual(len(executors), 4)
|
||||
|
||||
|
||||
@ -31,8 +31,6 @@ from test import (
|
||||
)
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
rs_or_single_client,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -57,7 +55,9 @@ class TestCommandMonitoring(IntegrationTest):
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = EventListener()
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False)
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
@ -405,7 +405,7 @@ class TestCommandMonitoring(IntegrationTest):
|
||||
@client_context.require_secondaries_count(1)
|
||||
def test_not_primary_error(self):
|
||||
address = next(iter(client_context.client.secondaries))
|
||||
client = single_client(*address, event_listeners=[self.listener])
|
||||
client = self.single_client(*address, event_listeners=[self.listener])
|
||||
# Clear authentication command results from the listener.
|
||||
client.admin.command("ping")
|
||||
self.listener.reset()
|
||||
@ -1144,7 +1144,7 @@ class TestGlobalListener(IntegrationTest):
|
||||
# We plan to call register(), which internally modifies _LISTENERS.
|
||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||
monitoring.register(cls.listener)
|
||||
cls.client = single_client()
|
||||
cls.client = cls.unmanaged_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
cls.client.pymongo_test.command("ping")
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from pymongo.hello import HelloCompat
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import delay, get_pool, joinall, rs_or_single_client
|
||||
from test.utils import delay, get_pool, joinall
|
||||
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.synchronous.pool import Pool, PoolOptions
|
||||
@ -151,7 +151,7 @@ class _TestPoolingBase(IntegrationTest):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.c = rs_or_single_client()
|
||||
self.c = self.rs_or_single_client()
|
||||
db = self.c[DB]
|
||||
db.unique.drop()
|
||||
db.test.drop()
|
||||
@ -378,7 +378,7 @@ class TestPooling(_TestPoolingBase):
|
||||
socket_info.close_conn(None)
|
||||
|
||||
def test_maxConnecting(self):
|
||||
client = rs_or_single_client()
|
||||
client = self.rs_or_single_client()
|
||||
self.addCleanup(client.close)
|
||||
self.client.test.test.insert_one({})
|
||||
self.addCleanup(self.client.test.test.delete_many, {})
|
||||
@ -415,7 +415,7 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_csot_timeout_message(self):
|
||||
client = rs_or_single_client(appName="connectionTimeoutApp")
|
||||
client = self.rs_or_single_client(appName="connectionTimeoutApp")
|
||||
self.addCleanup(client.close)
|
||||
# Mock an operation failing due to pymongo.timeout().
|
||||
mock_connection_timeout = {
|
||||
@ -440,7 +440,7 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_socket_timeout_message(self):
|
||||
client = rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp")
|
||||
client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp")
|
||||
self.addCleanup(client.close)
|
||||
# Mock an operation failing due to socketTimeoutMS.
|
||||
mock_connection_timeout = {
|
||||
@ -479,7 +479,7 @@ class TestPooling(_TestPoolingBase):
|
||||
},
|
||||
}
|
||||
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
connectTimeoutMS=500,
|
||||
socketTimeoutMS=500,
|
||||
appName="connectionTimeoutApp",
|
||||
@ -502,7 +502,7 @@ class TestPooling(_TestPoolingBase):
|
||||
class TestPoolMaxSize(_TestPoolingBase):
|
||||
def test_max_pool_size(self):
|
||||
max_pool_size = 4
|
||||
c = rs_or_single_client(maxPoolSize=max_pool_size)
|
||||
c = self.rs_or_single_client(maxPoolSize=max_pool_size)
|
||||
self.addCleanup(c.close)
|
||||
collection = c[DB].test
|
||||
|
||||
@ -538,7 +538,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
self.assertEqual(0, cx_pool.requests)
|
||||
|
||||
def test_max_pool_size_none(self):
|
||||
c = rs_or_single_client(maxPoolSize=None)
|
||||
c = self.rs_or_single_client(maxPoolSize=None)
|
||||
self.addCleanup(c.close)
|
||||
collection = c[DB].test
|
||||
|
||||
@ -570,7 +570,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
self.assertEqual(cx_pool.max_pool_size, float("inf"))
|
||||
|
||||
def test_max_pool_size_zero(self):
|
||||
c = rs_or_single_client(maxPoolSize=0)
|
||||
c = self.rs_or_single_client(maxPoolSize=0)
|
||||
self.addCleanup(c.close)
|
||||
pool = get_pool(c)
|
||||
self.assertEqual(pool.max_pool_size, float("inf"))
|
||||
|
||||
@ -21,7 +21,7 @@ import unittest
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context
|
||||
from test.utils import OvertCommandListener, rs_or_single_client
|
||||
from test.utils import OvertCommandListener
|
||||
|
||||
from bson.son import SON
|
||||
from pymongo.errors import OperationFailure
|
||||
@ -36,7 +36,7 @@ class TestReadConcern(IntegrationTest):
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
client_context.client.pymongo_test.create_collection("coll")
|
||||
|
||||
@ -67,7 +67,7 @@ class TestReadConcern(IntegrationTest):
|
||||
|
||||
def test_read_concern_uri(self):
|
||||
uri = f"mongodb://{client_context.pair}/?readConcernLevel=majority"
|
||||
client = rs_or_single_client(uri, connect=False)
|
||||
client = self.rs_or_single_client(uri, connect=False)
|
||||
self.assertEqual(ReadConcern("majority"), client.read_concern)
|
||||
|
||||
def test_invalid_read_concern(self):
|
||||
|
||||
@ -30,8 +30,6 @@ from test import IntegrationTest, SkipTest, client_context, connected, unittest
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
one,
|
||||
rs_client,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
from test.version import Version
|
||||
@ -58,7 +56,7 @@ from pymongo.write_concern import WriteConcern
|
||||
class TestSelections(IntegrationTest):
|
||||
@client_context.require_connection
|
||||
def test_bool(self):
|
||||
client = single_client()
|
||||
client = self.single_client()
|
||||
|
||||
wait_until(lambda: client.address, "discover primary")
|
||||
selection = Selection.from_topology_description(client._topology.description)
|
||||
@ -128,7 +126,7 @@ class TestReadPreferencesBase(IntegrationTest):
|
||||
return None
|
||||
|
||||
def assertReadsFrom(self, expected, **kwargs):
|
||||
c = rs_client(**kwargs)
|
||||
c = self.rs_client(**kwargs)
|
||||
wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes")
|
||||
|
||||
used = self.read_from_which_kind(c)
|
||||
@ -139,7 +137,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase):
|
||||
def test_reads_from_secondary(self):
|
||||
host, port = next(iter(self.client.secondaries))
|
||||
# Direct connection to a secondary.
|
||||
client = single_client(host, port)
|
||||
client = self.single_client(host, port)
|
||||
self.assertFalse(client.is_primary)
|
||||
|
||||
# Regardless of read preference, we should be able to do
|
||||
@ -175,19 +173,21 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
ReadPreference.SECONDARY_PREFERRED,
|
||||
ReadPreference.NEAREST,
|
||||
):
|
||||
self.assertEqual(mode, rs_client(read_preference=mode).read_preference)
|
||||
self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference)
|
||||
|
||||
self.assertRaises(TypeError, rs_client, read_preference="foo")
|
||||
self.assertRaises(TypeError, self.rs_client, read_preference="foo")
|
||||
|
||||
def test_tag_sets_validation(self):
|
||||
S = Secondary(tag_sets=[{}])
|
||||
self.assertEqual([{}], rs_client(read_preference=S).read_preference.tag_sets)
|
||||
self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets)
|
||||
|
||||
S = Secondary(tag_sets=[{"k": "v"}])
|
||||
self.assertEqual([{"k": "v"}], rs_client(read_preference=S).read_preference.tag_sets)
|
||||
self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets)
|
||||
|
||||
S = Secondary(tag_sets=[{"k": "v"}, {}])
|
||||
self.assertEqual([{"k": "v"}, {}], rs_client(read_preference=S).read_preference.tag_sets)
|
||||
self.assertEqual(
|
||||
[{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets
|
||||
)
|
||||
|
||||
self.assertRaises(ValueError, Secondary, tag_sets=[])
|
||||
|
||||
@ -200,20 +200,22 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
|
||||
def test_threshold_validation(self):
|
||||
self.assertEqual(
|
||||
17, rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms
|
||||
17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
42, rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms
|
||||
42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
666, rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms
|
||||
666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms
|
||||
)
|
||||
|
||||
self.assertEqual(0, rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms)
|
||||
self.assertEqual(
|
||||
0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms
|
||||
)
|
||||
|
||||
self.assertRaises(ValueError, rs_client, localthresholdms=-1)
|
||||
self.assertRaises(ValueError, self.rs_client, localthresholdms=-1)
|
||||
|
||||
def test_zero_latency(self):
|
||||
ping_times: set = set()
|
||||
@ -223,7 +225,7 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
for ping_time, host in zip(ping_times, self.client.nodes):
|
||||
ServerDescription._host_to_round_trip_time[host] = ping_time
|
||||
try:
|
||||
client = connected(rs_client(readPreference="nearest", localThresholdMS=0))
|
||||
client = connected(self.rs_client(readPreference="nearest", localThresholdMS=0))
|
||||
wait_until(lambda: client.nodes == self.client.nodes, "discovered all nodes")
|
||||
host = self.read_from_which_host(client)
|
||||
for _ in range(5):
|
||||
@ -236,7 +238,7 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
|
||||
def test_primary_with_tags(self):
|
||||
# Tags not allowed with PRIMARY
|
||||
self.assertRaises(ConfigurationError, rs_client, tag_sets=[{"dc": "ny"}])
|
||||
self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}])
|
||||
|
||||
def test_primary_preferred(self):
|
||||
self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED)
|
||||
@ -250,7 +252,9 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
def test_nearest(self):
|
||||
# With high localThresholdMS, expect to read from any
|
||||
# member
|
||||
c = rs_client(read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds
|
||||
c = self.rs_client(
|
||||
read_preference=ReadPreference.NEAREST, localThresholdMS=10000
|
||||
) # 10 seconds
|
||||
|
||||
data_members = {self.client.primary} | self.client.secondaries
|
||||
|
||||
@ -540,7 +544,7 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
if client_context.supports_secondary_read_pref:
|
||||
cases["secondary"] = Secondary
|
||||
listener = OvertCommandListener()
|
||||
client = rs_client(event_listeners=[listener])
|
||||
client = self.rs_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping")
|
||||
for _mode, cls in cases.items():
|
||||
@ -667,13 +671,13 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
else:
|
||||
self.fail("mongos accepted invalid staleness")
|
||||
|
||||
coll = single_client(
|
||||
coll = self.single_client(
|
||||
readPreference="secondaryPreferred", maxStalenessSeconds=120
|
||||
).pymongo_test.test
|
||||
# No error
|
||||
coll.find_one()
|
||||
|
||||
coll = single_client(
|
||||
coll = self.single_client(
|
||||
readPreference="secondaryPreferred", maxStalenessSeconds=10
|
||||
).pymongo_test.test
|
||||
try:
|
||||
|
||||
@ -24,12 +24,7 @@ sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
disable_replication,
|
||||
enable_replication,
|
||||
rs_or_single_client,
|
||||
)
|
||||
from test.utils import EventListener
|
||||
|
||||
from pymongo import DESCENDING
|
||||
from pymongo.errors import (
|
||||
@ -51,7 +46,7 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
def test_omit_default_read_write_concern(self):
|
||||
listener = EventListener()
|
||||
# Client with default readConcern and writeConcern
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
collection = client.pymongo_test.collection
|
||||
# Prepare for tests of find() and aggregate().
|
||||
@ -104,7 +99,9 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
def assertWriteOpsRaise(self, write_concern, expected_exception):
|
||||
wc = write_concern.document
|
||||
# Set socket timeout to avoid indefinite stalls
|
||||
client = rs_or_single_client(w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000)
|
||||
client = self.rs_or_single_client(
|
||||
w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000
|
||||
)
|
||||
db = client.get_database("pymongo_test")
|
||||
coll = db.test
|
||||
|
||||
@ -167,9 +164,9 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
@client_context.require_test_commands
|
||||
def test_raise_wtimeout(self):
|
||||
self.addCleanup(client_context.client.drop_database, "pymongo_test")
|
||||
self.addCleanup(enable_replication, client_context.client)
|
||||
self.addCleanup(self.enable_replication, client_context.client)
|
||||
# Disable replication to guarantee a wtimeout error.
|
||||
disable_replication(client_context.client)
|
||||
self.disable_replication(client_context.client)
|
||||
self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError)
|
||||
|
||||
@client_context.require_failCommand_fail_point
|
||||
@ -209,7 +206,7 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
@client_context.require_version_min(4, 9)
|
||||
def test_write_error_details_exposes_errinfo(self):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
db = client.errinfotest
|
||||
self.addCleanup(client.drop_database, "errinfotest")
|
||||
|
||||
@ -34,7 +34,6 @@ from test import (
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
rs_or_single_client,
|
||||
set_fail_point,
|
||||
)
|
||||
|
||||
@ -93,7 +92,9 @@ class TestPoolPausedError(IntegrationTest):
|
||||
self.skipTest("Test is flakey on PyPy")
|
||||
cmap_listener = CMAPListener()
|
||||
cmd_listener = OvertCommandListener()
|
||||
client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener])
|
||||
client = self.rs_or_single_client(
|
||||
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
for _ in range(10):
|
||||
cmap_listener.reset()
|
||||
@ -163,13 +164,13 @@ class TestRetryableReads(IntegrationTest):
|
||||
mongos_clients = []
|
||||
|
||||
for mongos in client_context.mongos_seeds().split(","):
|
||||
client = rs_or_single_client(mongos)
|
||||
client = self.rs_or_single_client(mongos)
|
||||
set_fail_point(client, fail_command)
|
||||
self.addCleanup(client.close)
|
||||
mongos_clients.append(client)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
client_context.mongos_seeds(),
|
||||
appName="retryableReadTest",
|
||||
event_listeners=[listener],
|
||||
|
||||
@ -28,7 +28,6 @@ from test.utils import (
|
||||
DeprecationFilter,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
rs_or_single_client,
|
||||
set_fail_point,
|
||||
)
|
||||
from test.version import Version
|
||||
@ -145,7 +144,7 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.client = rs_or_single_client(retryWrites=True)
|
||||
cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
@ -181,7 +180,9 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = rs_or_single_client(retryWrites=True, event_listeners=[cls.listener])
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[cls.listener]
|
||||
)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
@ -204,7 +205,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
|
||||
def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(retryWrites=False, event_listeners=[listener])
|
||||
client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
@ -297,7 +298,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
def test_server_selection_timeout_not_retried(self):
|
||||
"""A ServerSelectionTimeoutError is not retried."""
|
||||
listener = OvertCommandListener()
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
"somedomainthatdoesntexist.org",
|
||||
serverSelectionTimeoutMS=1,
|
||||
retryWrites=True,
|
||||
@ -317,7 +318,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
original error.
|
||||
"""
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
topology = client._topology
|
||||
select_server = topology.select_server
|
||||
@ -443,13 +444,13 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
mongos_clients = []
|
||||
|
||||
for mongos in client_context.mongos_seeds().split(","):
|
||||
client = rs_or_single_client(mongos)
|
||||
client = self.rs_or_single_client(mongos)
|
||||
set_fail_point(client, fail_command)
|
||||
self.addCleanup(client.close)
|
||||
mongos_clients.append(client)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
client_context.mongos_seeds(),
|
||||
appName="retryableWriteTest",
|
||||
event_listeners=[listener],
|
||||
@ -492,7 +493,7 @@ class TestWriteConcernError(IntegrationTest):
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
def test_RetryableWriteError_error_label(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# Ensure collection exists.
|
||||
@ -551,7 +552,9 @@ class TestPoolPausedError(IntegrationTest):
|
||||
def test_pool_paused_error_is_retryable(self):
|
||||
cmap_listener = CMAPListener()
|
||||
cmd_listener = OvertCommandListener()
|
||||
client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener])
|
||||
client = self.rs_or_single_client(
|
||||
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
for _ in range(10):
|
||||
cmap_listener.reset()
|
||||
@ -613,7 +616,7 @@ class TestPoolPausedError(IntegrationTest):
|
||||
self,
|
||||
):
|
||||
cmd_listener = InsertEventListener()
|
||||
client = rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener])
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener])
|
||||
client.test.test.drop()
|
||||
self.addCleanup(client.close)
|
||||
cmd_listener.reset()
|
||||
@ -650,7 +653,7 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
|
||||
the first attempt fails before sending the command.
|
||||
"""
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
topology = client._topology
|
||||
select_server = topology.select_server
|
||||
|
||||
@ -25,7 +25,6 @@ sys.path[0:0] = [""]
|
||||
from test import IntegrationTest, client_context, client_knobs, unittest
|
||||
from test.utils import (
|
||||
ServerAndTopologyEventListener,
|
||||
rs_or_single_client,
|
||||
server_name_to_type,
|
||||
wait_until,
|
||||
)
|
||||
@ -279,7 +278,7 @@ class TestSdamMonitoring(IntegrationTest):
|
||||
cls.knobs.enable()
|
||||
cls.listener = ServerAndTopologyEventListener()
|
||||
retry_writes = client_context.supports_transactions()
|
||||
cls.test_client = rs_or_single_client(
|
||||
cls.test_client = cls.unmanaged_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=retry_writes
|
||||
)
|
||||
cls.coll = cls.test_client[cls.client.db.name].test
|
||||
|
||||
@ -33,7 +33,6 @@ from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
FunctionCallRecorder,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_selection_tests import (
|
||||
@ -76,7 +75,9 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
|
||||
# Initialize client with appropriate listeners.
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(server_selector=custom_selector, event_listeners=[listener])
|
||||
client = self.rs_or_single_client(
|
||||
server_selector=custom_selector, event_listeners=[listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll
|
||||
self.addCleanup(client.drop_database, "testdb")
|
||||
@ -117,7 +118,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
selector = FunctionCallRecorder(lambda x: x)
|
||||
|
||||
# Client setup.
|
||||
mongo_client = rs_or_single_client(server_selector=selector)
|
||||
mongo_client = self.rs_or_single_client(server_selector=selector)
|
||||
test_collection = mongo_client.testdb.test_collection
|
||||
self.addCleanup(mongo_client.close)
|
||||
self.addCleanup(mongo_client.drop_database, "testdb")
|
||||
|
||||
@ -22,7 +22,6 @@ from test.utils import (
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
get_pool,
|
||||
rs_client,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_selection_tests import create_topology
|
||||
@ -134,7 +133,7 @@ class TestProse(IntegrationTest):
|
||||
listener = OvertCommandListener()
|
||||
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
|
||||
# varying RTTs.
|
||||
client = rs_client(
|
||||
client = self.rs_client(
|
||||
client_context.mongos_seeds(),
|
||||
appName="loadBalancingTest",
|
||||
event_listeners=[listener],
|
||||
|
||||
@ -36,7 +36,6 @@ from test import (
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
ExceptionCatchingThread,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -88,7 +87,7 @@ class TestSession(IntegrationTest):
|
||||
super()._setup_class()
|
||||
# Create a second client so we can make sure clients cannot share
|
||||
# sessions.
|
||||
cls.client2 = rs_or_single_client()
|
||||
cls.client2 = cls.unmanaged_rs_or_single_client()
|
||||
|
||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
@ -103,7 +102,7 @@ class TestSession(IntegrationTest):
|
||||
def setUp(self):
|
||||
self.listener = SessionTestListener()
|
||||
self.session_checker_listener = SessionTestListener()
|
||||
self.client = rs_or_single_client(
|
||||
self.client = self.rs_or_single_client(
|
||||
event_listeners=[self.listener, self.session_checker_listener]
|
||||
)
|
||||
self.addCleanup(self.client.close)
|
||||
@ -200,7 +199,7 @@ class TestSession(IntegrationTest):
|
||||
failures = 0
|
||||
for _ in range(5):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
|
||||
client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
|
||||
cursor = client.db.test.find({})
|
||||
ops: List[Tuple[Callable, List[Any]]] = [
|
||||
(client.db.test.find_one, [{"_id": 1}]),
|
||||
@ -283,7 +282,7 @@ class TestSession(IntegrationTest):
|
||||
def test_end_sessions(self):
|
||||
# Use a new client so that the tearDown hook does not error.
|
||||
listener = SessionTestListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
# Start many sessions.
|
||||
sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)]
|
||||
for s in sessions:
|
||||
@ -787,8 +786,7 @@ class TestSession(IntegrationTest):
|
||||
def test_unacknowledged_writes(self):
|
||||
# Ensure the collection exists.
|
||||
self.client.pymongo_test.test_unacked_writes.insert_one({})
|
||||
client = rs_or_single_client(w=0, event_listeners=[self.listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(w=0, event_listeners=[self.listener])
|
||||
db = client.pymongo_test
|
||||
coll = db.test_unacked_writes
|
||||
ops: list = [
|
||||
@ -836,7 +834,7 @@ class TestCausalConsistency(UnitTest):
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls.listener = SessionTestListener()
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
@ -1137,8 +1135,7 @@ class TestClusterTime(IntegrationTest):
|
||||
def test_cluster_time(self):
|
||||
listener = SessionTestListener()
|
||||
# Prevent heartbeats from updating $clusterTime between operations.
|
||||
client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999)
|
||||
collection = client.pymongo_test.collection
|
||||
# Prepare for tests of find() and aggregate().
|
||||
collection.insert_many([{} for _ in range(10)])
|
||||
|
||||
@ -21,7 +21,7 @@ from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import client_knobs, unittest
|
||||
from test import PyMongoTestCase, client_knobs, unittest
|
||||
from test.utils import FunctionCallRecorder, wait_until
|
||||
|
||||
import pymongo
|
||||
@ -86,7 +86,7 @@ class SrvPollingKnobs:
|
||||
self.disable()
|
||||
|
||||
|
||||
class TestSrvPolling(unittest.TestCase):
|
||||
class TestSrvPolling(PyMongoTestCase):
|
||||
BASE_SRV_RESPONSE = [
|
||||
("localhost.test.build.10gen.cc", 27017),
|
||||
("localhost.test.build.10gen.cc", 27018),
|
||||
@ -167,7 +167,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
|
||||
# Patch timeouts to ensure short test running times.
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = MongoClient(self.CONNECTION_STRING)
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
|
||||
# Patch list of hosts returned by DNS query.
|
||||
with SrvPollingKnobs(
|
||||
@ -231,7 +231,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
count_resolver_calls=True,
|
||||
):
|
||||
# Client uses unpatched method to get initial nodelist
|
||||
client = MongoClient(self.CONNECTION_STRING)
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
# Invalid DNS resolver response should not change nodelist.
|
||||
self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
|
||||
|
||||
@ -264,8 +264,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=0)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
self.assert_nodelist_change(response, client)
|
||||
|
||||
@ -279,8 +278,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
self.assert_nodelist_change(response, client)
|
||||
|
||||
@ -295,8 +293,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
|
||||
final_topology = set(client.topology_description.server_descriptions())
|
||||
@ -305,8 +302,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
|
||||
def test_does_not_flipflop(self):
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=1)
|
||||
self.addCleanup(client.close)
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
|
||||
old = set(client.topology_description.server_descriptions())
|
||||
sleep(4 * WAIT_TIME)
|
||||
new = set(client.topology_description.server_descriptions())
|
||||
@ -323,7 +319,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
|
||||
)
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
@ -340,7 +336,7 @@ class TestSrvPolling(unittest.TestCase):
|
||||
min_srv_rescan_interval=WAIT_TIME,
|
||||
nodelist_callback=resolver_response,
|
||||
):
|
||||
client = MongoClient(self.CONNECTION_STRING)
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
self.assertRaises(
|
||||
AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2
|
||||
)
|
||||
|
||||
101
test/test_ssl.py
101
test/test_ssl.py
@ -24,6 +24,7 @@ sys.path[0:0] = [""]
|
||||
from test import (
|
||||
HAVE_IPADDRESS,
|
||||
IntegrationTest,
|
||||
PyMongoTestCase,
|
||||
SkipTest,
|
||||
client_context,
|
||||
connected,
|
||||
@ -82,45 +83,45 @@ MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=cl
|
||||
# use 'localhost' for the hostname of all hosts.
|
||||
|
||||
|
||||
class TestClientSSL(unittest.TestCase):
|
||||
class TestClientSSL(PyMongoTestCase):
|
||||
@unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.")
|
||||
def test_no_ssl_module(self):
|
||||
# Explicit
|
||||
self.assertRaises(ConfigurationError, MongoClient, ssl=True)
|
||||
self.assertRaises(ConfigurationError, self.simple_client, ssl=True)
|
||||
|
||||
# Implied
|
||||
self.assertRaises(ConfigurationError, MongoClient, tlsCertificateKeyFile=CLIENT_PEM)
|
||||
self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM)
|
||||
|
||||
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
|
||||
@ignore_deprecations
|
||||
def test_config_ssl(self):
|
||||
# Tests various ssl configurations
|
||||
self.assertRaises(ValueError, MongoClient, ssl="foo")
|
||||
self.assertRaises(ValueError, self.simple_client, ssl="foo")
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM
|
||||
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.assertRaises(TypeError, MongoClient, ssl=0)
|
||||
self.assertRaises(TypeError, MongoClient, ssl=5.5)
|
||||
self.assertRaises(TypeError, MongoClient, ssl=[])
|
||||
self.assertRaises(TypeError, self.simple_client, ssl=0)
|
||||
self.assertRaises(TypeError, self.simple_client, ssl=5.5)
|
||||
self.assertRaises(TypeError, self.simple_client, ssl=[])
|
||||
|
||||
self.assertRaises(IOError, MongoClient, tlsCertificateKeyFile="NoSuchFile")
|
||||
self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=True)
|
||||
self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=[])
|
||||
self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile")
|
||||
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True)
|
||||
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[])
|
||||
|
||||
# Test invalid combinations
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM
|
||||
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCAFile=CA_PEM)
|
||||
self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCRLFile=CRL_PEM)
|
||||
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM)
|
||||
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM)
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, tls=False, tlsAllowInvalidCertificates=False
|
||||
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False
|
||||
)
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, tls=False, tlsAllowInvalidHostnames=False
|
||||
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False
|
||||
)
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, tls=False, tlsDisableOCSPEndpointCheck=False
|
||||
ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False
|
||||
)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
|
||||
@ -174,7 +175,7 @@ class TestSSL(IntegrationTest):
|
||||
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
MongoClient,
|
||||
self.simple_client,
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
|
||||
@ -184,7 +185,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
else:
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
|
||||
@ -201,7 +202,7 @@ class TestSSL(IntegrationTest):
|
||||
"&tlsCAFile=%s&serverSelectionTimeoutMS=5000"
|
||||
)
|
||||
connected(
|
||||
MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
|
||||
self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@ -215,7 +216,7 @@ class TestSSL(IntegrationTest):
|
||||
#
|
||||
|
||||
# test that setting tlsCertificateKeyFile causes ssl to be set to True
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
client_context.host,
|
||||
client_context.port,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
@ -223,7 +224,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
response = client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if "setName" in response:
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
client_context.pair,
|
||||
replicaSet=response["setName"],
|
||||
w=len(response["hosts"]),
|
||||
@ -242,7 +243,7 @@ class TestSSL(IntegrationTest):
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
#
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
@ -257,7 +258,7 @@ class TestSSL(IntegrationTest):
|
||||
"Cannot validate hostname in the certificate"
|
||||
)
|
||||
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
"localhost",
|
||||
replicaSet=response["setName"],
|
||||
w=len(response["hosts"]),
|
||||
@ -270,7 +271,7 @@ class TestSSL(IntegrationTest):
|
||||
self.assertClientWorks(client)
|
||||
|
||||
if HAVE_IPADDRESS:
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
"127.0.0.1",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
@ -292,7 +293,7 @@ class TestSSL(IntegrationTest):
|
||||
"mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates"
|
||||
"=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false"
|
||||
)
|
||||
client = MongoClient(uri_fmt % (CLIENT_PEM, "true", CA_PEM))
|
||||
client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM))
|
||||
self.assertClientWorks(client)
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@ -316,7 +317,7 @@ class TestSSL(IntegrationTest):
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"server",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
@ -328,7 +329,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"server",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
@ -343,7 +344,7 @@ class TestSSL(IntegrationTest):
|
||||
if "setName" in response:
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"server",
|
||||
replicaSet=response["setName"],
|
||||
ssl=True,
|
||||
@ -356,7 +357,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"server",
|
||||
replicaSet=response["setName"],
|
||||
ssl=True,
|
||||
@ -375,7 +376,7 @@ class TestSSL(IntegrationTest):
|
||||
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
MongoClient,
|
||||
self.simple_client,
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
@ -384,7 +385,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
else:
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
@ -395,7 +396,7 @@ class TestSSL(IntegrationTest):
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
@ -406,7 +407,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
|
||||
uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000"
|
||||
connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore
|
||||
connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore
|
||||
|
||||
uri_fmt = (
|
||||
"mongodb://localhost/?ssl=true&tlsCRLFile=%s"
|
||||
@ -414,7 +415,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(
|
||||
MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
|
||||
self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@ -431,12 +432,14 @@ class TestSSL(IntegrationTest):
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
# Server cert is verified but hostname matching fails
|
||||
connected(
|
||||
MongoClient("server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type]
|
||||
self.simple_client(
|
||||
"server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
|
||||
) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Server cert is verified. Disable hostname matching.
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"server",
|
||||
ssl=True,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
@ -447,12 +450,14 @@ class TestSSL(IntegrationTest):
|
||||
|
||||
# Server cert and hostname are verified.
|
||||
connected(
|
||||
MongoClient("localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type]
|
||||
self.simple_client(
|
||||
"localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
|
||||
) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Server cert and hostname are verified.
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
"mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000",
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
@ -472,7 +477,7 @@ class TestSSL(IntegrationTest):
|
||||
ssl_support.HAVE_WINCERTSTORE = False
|
||||
try:
|
||||
with self.assertRaises(ConfigurationError):
|
||||
MongoClient("mongodb://localhost/?ssl=true")
|
||||
self.simple_client("mongodb://localhost/?ssl=true")
|
||||
finally:
|
||||
ssl_support.HAVE_CERTIFI = have_certifi
|
||||
ssl_support.HAVE_WINCERTSTORE = have_wincertstore
|
||||
@ -536,7 +541,7 @@ class TestSSL(IntegrationTest):
|
||||
],
|
||||
)
|
||||
|
||||
noauth = MongoClient(
|
||||
noauth = self.simple_client(
|
||||
client_context.pair,
|
||||
ssl=True,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
@ -548,7 +553,7 @@ class TestSSL(IntegrationTest):
|
||||
noauth.pymongo_test.test.find_one()
|
||||
|
||||
listener = EventListener()
|
||||
auth = MongoClient(
|
||||
auth = self.simple_client(
|
||||
client_context.pair,
|
||||
authMechanism="MONGODB-X509",
|
||||
ssl=True,
|
||||
@ -572,7 +577,7 @@ class TestSSL(IntegrationTest):
|
||||
host,
|
||||
port,
|
||||
)
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -580,7 +585,7 @@ class TestSSL(IntegrationTest):
|
||||
client.pymongo_test.test.find_one()
|
||||
|
||||
uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port)
|
||||
client = MongoClient(
|
||||
client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -593,7 +598,7 @@ class TestSSL(IntegrationTest):
|
||||
port,
|
||||
)
|
||||
|
||||
bad_client = MongoClient(
|
||||
bad_client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.addCleanup(bad_client.close)
|
||||
@ -601,7 +606,7 @@ class TestSSL(IntegrationTest):
|
||||
with self.assertRaises(OperationFailure):
|
||||
bad_client.pymongo_test.test.find_one()
|
||||
|
||||
bad_client = MongoClient(
|
||||
bad_client = self.simple_client(
|
||||
client_context.pair,
|
||||
username="not the username",
|
||||
authMechanism="MONGODB-X509",
|
||||
@ -622,7 +627,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
try:
|
||||
connected(
|
||||
MongoClient(
|
||||
self.simple_client(
|
||||
uri,
|
||||
ssl=True,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
@ -648,7 +653,7 @@ class TestSSL(IntegrationTest):
|
||||
self.addCleanup(remove, temp_ca_bundle)
|
||||
# Add the CA cert file to the bundle.
|
||||
cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM)
|
||||
with MongoClient(
|
||||
with self.simple_client(
|
||||
"localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle
|
||||
) as client:
|
||||
self.assertTrue(client.admin.command("ping"))
|
||||
|
||||
@ -24,8 +24,6 @@ from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
HeartbeatEventListener,
|
||||
ServerEventListener,
|
||||
rs_or_single_client,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -38,7 +36,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
def test_failCommand_streaming(self):
|
||||
listener = ServerEventListener()
|
||||
hb_listener = HeartbeatEventListener()
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
event_listeners=[listener, hb_listener],
|
||||
heartbeatFrequencyMS=500,
|
||||
appName="failingHeartbeatTest",
|
||||
@ -107,7 +105,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
},
|
||||
}
|
||||
with self.fail_point(delay_hello):
|
||||
client = rs_or_single_client(
|
||||
client = self.rs_or_single_client(
|
||||
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -155,7 +153,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
}
|
||||
with self.fail_point(fail_hello):
|
||||
start = time.time()
|
||||
client = single_client(
|
||||
client = self.single_client(
|
||||
appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
@ -180,7 +178,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
@client_context.require_failCommand_appName
|
||||
def test_heartbeat_awaited_flag(self):
|
||||
hb_listener = HeartbeatEventListener()
|
||||
client = single_client(
|
||||
client = self.single_client(
|
||||
event_listeners=[hb_listener],
|
||||
heartbeatFrequencyMS=500,
|
||||
appName="heartbeatEventAwaitedFlag",
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from test.utils_spec_runner import SpecRunner
|
||||
|
||||
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
|
||||
|
||||
@ -25,8 +26,6 @@ sys.path[0:0] = [""]
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
rs_client,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
from typing import List
|
||||
@ -59,7 +58,18 @@ _IS_SYNC = True
|
||||
UNPIN_TEST_MAX_ATTEMPTS = 50
|
||||
|
||||
|
||||
class TestTransactions(IntegrationTest):
|
||||
class TransactionsBase(SpecRunner):
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
if (
|
||||
"secondary" in self.id()
|
||||
and not client_context.is_mongos
|
||||
and not client_context.has_secondaries
|
||||
):
|
||||
raise unittest.SkipTest("No secondaries")
|
||||
|
||||
|
||||
class TestTransactions(TransactionsBase):
|
||||
RUN_ON_SERVERLESS = True
|
||||
|
||||
@client_context.require_transactions
|
||||
@ -92,8 +102,7 @@ class TestTransactions(IntegrationTest):
|
||||
@client_context.require_transactions
|
||||
def test_transaction_write_concern_override(self):
|
||||
"""Test txn overrides Client/Database/Collection write_concern."""
|
||||
client = rs_client(w=0)
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_client(w=0)
|
||||
db = client.test
|
||||
coll = db.test
|
||||
coll.insert_one({})
|
||||
@ -146,12 +155,11 @@ class TestTransactions(IntegrationTest):
|
||||
def test_unpin_for_next_transaction(self):
|
||||
# Increase localThresholdMS and wait until both nodes are discovered
|
||||
# to avoid false positives.
|
||||
client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
|
||||
client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
coll.insert_one({})
|
||||
self.addCleanup(client.close)
|
||||
with client.start_session() as s:
|
||||
# Session is pinned to Mongos.
|
||||
with s.start_transaction():
|
||||
@ -174,12 +182,11 @@ class TestTransactions(IntegrationTest):
|
||||
def test_unpin_for_non_transaction_operation(self):
|
||||
# Increase localThresholdMS and wait until both nodes are discovered
|
||||
# to avoid false positives.
|
||||
client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
|
||||
client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
coll.insert_one({})
|
||||
self.addCleanup(client.close)
|
||||
with client.start_session() as s:
|
||||
# Session is pinned to Mongos.
|
||||
with s.start_transaction():
|
||||
@ -303,11 +310,10 @@ class TestTransactions(IntegrationTest):
|
||||
# Start a transaction with a batch of operations that needs to be
|
||||
# split.
|
||||
listener = OvertCommandListener()
|
||||
client = rs_client(event_listeners=[listener])
|
||||
client = self.rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
coll.delete_many({})
|
||||
listener.reset()
|
||||
self.addCleanup(client.close)
|
||||
self.addCleanup(coll.drop)
|
||||
large_str = "\0" * (1 * 1024 * 1024)
|
||||
ops: List[InsertOne[RawBSONDocument]] = [
|
||||
@ -332,8 +338,7 @@ class TestTransactions(IntegrationTest):
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_transaction_direct_connection(self):
|
||||
client = single_client()
|
||||
self.addCleanup(client.close)
|
||||
client = self.single_client()
|
||||
coll = client.pymongo_test.test
|
||||
|
||||
# Make sure the collection exists.
|
||||
@ -389,14 +394,14 @@ class PatchSessionTimeout:
|
||||
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout
|
||||
|
||||
|
||||
class TestTransactionsConvenientAPI(IntegrationTest):
|
||||
class TestTransactionsConvenientAPI(TransactionsBase):
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
if client_context.supports_transactions():
|
||||
for address in client_context.mongoses:
|
||||
cls.mongos_clients.append(single_client("{}:{}".format(*address)))
|
||||
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
@ -446,8 +451,7 @@ class TestTransactionsConvenientAPI(IntegrationTest):
|
||||
@client_context.require_transactions
|
||||
def test_callback_not_retried_after_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
|
||||
def callback(session):
|
||||
@ -475,8 +479,7 @@ class TestTransactionsConvenientAPI(IntegrationTest):
|
||||
@client_context.require_transactions
|
||||
def test_callback_not_retried_after_commit_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
|
||||
def callback(session):
|
||||
@ -508,8 +511,7 @@ class TestTransactionsConvenientAPI(IntegrationTest):
|
||||
@client_context.require_transactions
|
||||
def test_commit_not_retried_after_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client = self.rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
|
||||
def callback(session):
|
||||
|
||||
@ -68,8 +68,7 @@ except ImportError:
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context
|
||||
from test.utils import rs_or_single_client
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context
|
||||
|
||||
from bson import CodecOptions, decode, decode_all, decode_file_iter, decode_iter, encode
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
@ -194,7 +193,7 @@ class TestPymongo(IntegrationTest):
|
||||
value.items()
|
||||
|
||||
def test_default_document_type(self) -> None:
|
||||
client = rs_or_single_client()
|
||||
client = self.rs_or_single_client()
|
||||
self.addCleanup(client.close)
|
||||
coll = client.test.test
|
||||
doc = {"my": "doc"}
|
||||
@ -366,7 +365,7 @@ class TestDecode(unittest.TestCase):
|
||||
doc["a"] = 2
|
||||
|
||||
|
||||
class TestDocumentType(unittest.TestCase):
|
||||
class TestDocumentType(PyMongoTestCase):
|
||||
@only_type_check
|
||||
def test_default(self) -> None:
|
||||
client: MongoClient = MongoClient()
|
||||
@ -480,7 +479,7 @@ class TestDocumentType(unittest.TestCase):
|
||||
def test_typeddict_find_notrequired(self):
|
||||
if NotRequired is None or ImplicitMovie is None:
|
||||
raise unittest.SkipTest("Python 3.11+ is required to use NotRequired.")
|
||||
client: MongoClient[ImplicitMovie] = rs_or_single_client()
|
||||
client: MongoClient[ImplicitMovie] = self.rs_or_single_client()
|
||||
coll = client.test.test
|
||||
coll.insert_one(ImplicitMovie(name="THX-1138", year=1971))
|
||||
out = coll.find_one({})
|
||||
|
||||
@ -20,7 +20,7 @@ sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import OvertCommandListener, rs_or_single_client
|
||||
from test.utils import OvertCommandListener
|
||||
|
||||
from pymongo.server_api import ServerApi, ServerApiVersion
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
@ -77,7 +77,7 @@ class TestServerApi(IntegrationTest):
|
||||
@client_context.require_version_min(4, 7)
|
||||
def test_command_options(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
|
||||
client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
coll = client.test.test
|
||||
coll.insert_many([{} for _ in range(100)])
|
||||
@ -90,7 +90,7 @@ class TestServerApi(IntegrationTest):
|
||||
@client_context.require_transactions
|
||||
def test_command_options_txn(self):
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
|
||||
client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
coll = client.test.test
|
||||
coll.insert_many([{} for _ in range(100)])
|
||||
|
||||
@ -55,8 +55,6 @@ from test.utils import (
|
||||
parse_collection_options,
|
||||
parse_spec_options,
|
||||
prepare_spec_arguments,
|
||||
rs_or_single_client,
|
||||
single_client,
|
||||
snake_to_camel,
|
||||
wait_until,
|
||||
)
|
||||
@ -574,7 +572,7 @@ class EntityMapUtil:
|
||||
)
|
||||
if uri:
|
||||
kwargs["h"] = uri
|
||||
client = rs_or_single_client(**kwargs)
|
||||
client = self.test.rs_or_single_client(**kwargs)
|
||||
self[spec["id"]] = client
|
||||
self.test.addCleanup(client.close)
|
||||
return
|
||||
@ -1115,7 +1113,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
and not client_context.serverless
|
||||
):
|
||||
for address in client_context.mongoses:
|
||||
cls.mongos_clients.append(single_client("{}:{}".format(*address)))
|
||||
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
|
||||
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(
|
||||
@ -1646,7 +1644,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
)
|
||||
)
|
||||
|
||||
client = single_client("{}:{}".format(*session._pinned_address))
|
||||
client = self.single_client("{}:{}".format(*session._pinned_address))
|
||||
self.addCleanup(client.close)
|
||||
self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
||||
|
||||
|
||||
159
test/utils.py
159
test/utils.py
@ -565,151 +565,6 @@ class SpecTestCreator:
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
|
||||
def _connection_string(h):
|
||||
if h.startswith(("mongodb://", "mongodb+srv://")):
|
||||
return h
|
||||
return f"mongodb://{h!s}"
|
||||
|
||||
|
||||
def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs):
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or client_context.host
|
||||
port = port or client_context.port
|
||||
client_options: dict = client_context.default_client_options.copy()
|
||||
if client_context.replica_set_name and not directConnection:
|
||||
client_options["replicaSet"] = client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
client_options["directConnection"] = directConnection
|
||||
client_options.update(kwargs)
|
||||
|
||||
uri = _connection_string(host)
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
and "username" not in client_options
|
||||
and "password" not in client_options
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
return MongoClient(uri, port, **client_options)
|
||||
|
||||
|
||||
async def _async_mongo_client(host, port, authenticate=True, directConnection=None, **kwargs):
|
||||
"""Create a new client over SSL/TLS if necessary."""
|
||||
host = host or await async_client_context.host
|
||||
port = port or await async_client_context.port
|
||||
client_options: dict = async_client_context.default_client_options.copy()
|
||||
if async_client_context.replica_set_name and not directConnection:
|
||||
client_options["replicaSet"] = async_client_context.replica_set_name
|
||||
if directConnection is not None:
|
||||
client_options["directConnection"] = directConnection
|
||||
client_options.update(kwargs)
|
||||
|
||||
uri = _connection_string(host)
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
and "username" not in client_options
|
||||
and "password" not in client_options
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
client = AsyncMongoClient(uri, port, **client_options)
|
||||
if client._options.connect:
|
||||
await client.aconnect()
|
||||
return client
|
||||
|
||||
|
||||
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
|
||||
|
||||
|
||||
def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Make a direct connection, and authenticate if necessary."""
|
||||
return _mongo_client(h, p, directConnection=True, **kwargs)
|
||||
|
||||
|
||||
def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Connect to the replica set. Don't authenticate."""
|
||||
return _mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
|
||||
def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Connect to the replica set and authenticate if necessary."""
|
||||
return _mongo_client(h, p, **kwargs)
|
||||
|
||||
|
||||
def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Like rs_or_single_client, but does not authenticate.
|
||||
"""
|
||||
return _mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
|
||||
def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Authenticates if necessary.
|
||||
"""
|
||||
return _mongo_client(h, p, **kwargs)
|
||||
|
||||
|
||||
async def async_single_client_noauth(
|
||||
h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection. Don't authenticate."""
|
||||
return await _async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
|
||||
|
||||
|
||||
async def async_single_client(
|
||||
h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Make a direct connection, and authenticate if necessary."""
|
||||
return await _async_mongo_client(h, p, directConnection=True, **kwargs)
|
||||
|
||||
|
||||
async def async_rs_client_noauth(
|
||||
h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set. Don't authenticate."""
|
||||
return await _async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
|
||||
async def async_rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set and authenticate if necessary."""
|
||||
return await _async_mongo_client(h, p, **kwargs)
|
||||
|
||||
|
||||
async def async_rs_or_single_client_noauth(
|
||||
h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[dict]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Like rs_or_single_client, but does not authenticate.
|
||||
"""
|
||||
return await _async_mongo_client(h, p, authenticate=False, **kwargs)
|
||||
|
||||
|
||||
async def async_rs_or_single_client(
|
||||
h: Any = None, p: Any = None, **kwargs: Any
|
||||
) -> AsyncMongoClient[Any]:
|
||||
"""Connect to the replica set if there is one, otherwise the standalone.
|
||||
|
||||
Authenticates if necessary.
|
||||
"""
|
||||
return await _async_mongo_client(h, p, **kwargs)
|
||||
|
||||
|
||||
def ensure_all_connected(client: MongoClient) -> None:
|
||||
"""Ensure that the client's connection pool has socket connections to all
|
||||
members of a replica set. Raises ConfigurationError when called with a
|
||||
@ -1108,20 +963,6 @@ def is_greenthread_patched():
|
||||
return gevent_monkey_patched() or eventlet_monkey_patched()
|
||||
|
||||
|
||||
def disable_replication(client):
|
||||
"""Disable replication on all secondaries."""
|
||||
for host, port in client.secondaries:
|
||||
secondary = single_client(host, port)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
|
||||
|
||||
|
||||
def enable_replication(client):
|
||||
"""Enable replication on all secondaries."""
|
||||
for host, port in client.secondaries:
|
||||
secondary = single_client(host, port)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
|
||||
|
||||
|
||||
class ExceptionCatchingThread(threading.Thread):
|
||||
"""A thread that stores any exception encountered from run()."""
|
||||
|
||||
|
||||
@ -29,7 +29,6 @@ from test.utils import (
|
||||
camel_to_snake_args,
|
||||
parse_spec_options,
|
||||
prepare_spec_arguments,
|
||||
rs_client,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
@ -101,6 +100,8 @@ class SpecRunner(IntegrationTest):
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
@ -524,7 +525,7 @@ class SpecRunner(IntegrationTest):
|
||||
host = client_context.MULTI_MONGOS_LB_URI
|
||||
elif client_context.is_mongos:
|
||||
host = client_context.mongos_seeds()
|
||||
client = rs_client(
|
||||
client = self.rs_client(
|
||||
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
|
||||
)
|
||||
self.scenario_client = client
|
||||
|
||||
Loading…
Reference in New Issue
Block a user