diff --git a/doc/changelog.rst b/doc/changelog.rst index 6fffcdf69..c008066c2 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -42,6 +42,16 @@ PyMongo 4.9 brings a number of improvements including: - Fixed a bug where PyMongo would raise ``InvalidBSON: date value out of range`` when using :attr:`~bson.codec_options.DatetimeConversion.DATETIME_CLAMP` or :attr:`~bson.codec_options.DatetimeConversion.DATETIME_AUTO` with a non-UTC timezone. +- Added a warning to unclosed MongoClient instances + telling users to explicitly close clients when finished with them to avoid leaking resources. + For example: + + .. code-block:: + + sys:1: ResourceWarning: Unclosed MongoClient opened at: + File "/Users//my_file.py", line 8, in `` + client = MongoClient() + Call MongoClient.close() to safely shut down your client and free up resources. - The default value for ``connect`` in ``MongoClient`` is changed to ``False`` when running on unction-as-a-service (FaaS) like AWS Lambda, Google Cloud Functions, and Microsoft Azure Functions. On some FaaS systems, there is a ``fork()`` operation at function diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b5e73e8de..a84fbf2e5 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -34,6 +34,7 @@ from __future__ import annotations import contextlib import os +import warnings import weakref from collections import defaultdict from typing import ( @@ -871,6 +872,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): ) self._opened = False + self._closed = False self._init_background() if _IS_SYNC and connect: @@ -1180,6 +1182,22 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): """ return database.AsyncDatabase(self, name) + def __del__(self) -> None: + """Check that this AsyncMongoClient has been closed and issue a warning if not.""" + try: + if not self._closed: + warnings.warn( + ( + f"Unclosed {type(self).__name__} opened at:\n{self._topology_settings._stack}" + f"Call {type(self).__name__}.close() to safely shut down your client and free up resources." + ), + ResourceWarning, + stacklevel=2, + source=self, + ) + except AttributeError: + pass + def _close_cursor_soon( self, cursor_id: int, @@ -1547,6 +1565,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() + self._closed = True if not _IS_SYNC: # Add support for contextlib.aclosing. diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py index c41c638e6..1103e1bd1 100644 --- a/pymongo/asynchronous/settings.py +++ b/pymongo/asynchronous/settings.py @@ -82,7 +82,7 @@ class TopologySettings: self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. - self._stack = "".join(traceback.format_stack()) + self._stack = "".join(traceback.format_stack()[:-2]) @property def seeds(self) -> Collection[tuple[str, int]]: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 26af488ac..cec78463b 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -34,6 +34,7 @@ from __future__ import annotations import contextlib import os +import warnings import weakref from collections import defaultdict from typing import ( @@ -871,6 +872,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): ) self._opened = False + self._closed = False self._init_background() if _IS_SYNC and connect: @@ -1180,6 +1182,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """ return database.Database(self, name) + def __del__(self) -> None: + """Check that this MongoClient has been closed and issue a warning if not.""" + try: + if not self._closed: + warnings.warn( + ( + f"Unclosed {type(self).__name__} opened at:\n{self._topology_settings._stack}" + f"Call {type(self).__name__}.close() to safely shut down your client and free up resources." + ), + ResourceWarning, + stacklevel=2, + source=self, + ) + except AttributeError: + pass + def _close_cursor_soon( self, cursor_id: int, @@ -1543,6 +1561,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() + self._closed = True if not _IS_SYNC: # Add support for contextlib.closing. diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index 8719e8608..040776713 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -82,7 +82,7 @@ class TopologySettings: self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. - self._stack = "".join(traceback.format_stack()) + self._stack = "".join(traceback.format_stack()[:-2]) @property def seeds(self) -> Collection[tuple[str, int]]: diff --git a/pyproject.toml b/pyproject.toml index 225be8e1d..19db00f19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,9 @@ filterwarnings = [ "module:please use dns.resolver.Resolver.resolve:DeprecationWarning", # https://github.com/dateutil/dateutil/issues/1314 "module:datetime.datetime.utc:DeprecationWarning:dateutil", + # TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731 + "ignore:Unclosed AsyncMongoClient*", + "ignore:Unclosed MongoClient*", ] markers = [ "auth_aws: tests that rely on pymongo-auth-aws", diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py new file mode 100644 index 000000000..e516ff679 --- /dev/null +++ b/test/asynchronous/test_auth.py @@ -0,0 +1,689 @@ +# Copyright 2013-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Authentication Tests.""" +from __future__ import annotations + +import asyncio +import os +import sys +import threading +from urllib.parse import quote_plus + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, SkipTest, async_client_context, unittest +from test.utils import ( + AllowListEventListener, + async_rs_or_single_client, + async_rs_or_single_client_noauth, + async_single_client, + async_single_client_noauth, + delay, + ignore_deprecations, +) + +from pymongo import AsyncMongoClient, monitoring +from pymongo.asynchronous.auth import HAVE_KERBEROS +from pymongo.auth_shared import _build_credentials_tuple +from pymongo.errors import OperationFailure +from pymongo.hello import HelloCompat +from pymongo.read_preferences import ReadPreference +from pymongo.saslprep import HAVE_STRINGPREP + +_IS_SYNC = False + +# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. +GSSAPI_HOST = os.environ.get("GSSAPI_HOST") +GSSAPI_PORT = int(os.environ.get("GSSAPI_PORT", "27017")) +GSSAPI_PRINCIPAL = os.environ.get("GSSAPI_PRINCIPAL") +GSSAPI_SERVICE_NAME = os.environ.get("GSSAPI_SERVICE_NAME", "mongodb") +GSSAPI_CANONICALIZE = os.environ.get("GSSAPI_CANONICALIZE", "false") +GSSAPI_SERVICE_REALM = os.environ.get("GSSAPI_SERVICE_REALM") +GSSAPI_PASS = os.environ.get("GSSAPI_PASS") +GSSAPI_DB = os.environ.get("GSSAPI_DB", "test") + +SASL_HOST = os.environ.get("SASL_HOST") +SASL_PORT = int(os.environ.get("SASL_PORT", "27017")) +SASL_USER = os.environ.get("SASL_USER") +SASL_PASS = os.environ.get("SASL_PASS") +SASL_DB = os.environ.get("SASL_DB", "$external") + + +class AutoAuthenticateThread(threading.Thread): + """Used in testing threaded authentication. + + This does await collection.find_one() with a 1-second delay to ensure it must + check out and authenticate multiple connections from the pool concurrently. + + :Parameters: + `collection`: An auth-protected collection containing one document. + """ + + def __init__(self, collection): + super().__init__() + self.collection = collection + self.success = False + + def run(self): + assert self.collection.find_one({"$where": delay(1)}) is not None + self.success = True + + +class TestGSSAPI(unittest.IsolatedAsyncioTestCase): + mech_properties: str + service_realm_required: bool + + @classmethod + def setUpClass(cls): + if not HAVE_KERBEROS: + raise SkipTest("Kerberos module not available.") + if not GSSAPI_HOST or not GSSAPI_PRINCIPAL: + raise SkipTest("Must set GSSAPI_HOST and GSSAPI_PRINCIPAL to test GSSAPI") + cls.service_realm_required = ( + GSSAPI_SERVICE_REALM is not None and GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL + ) + mech_properties = f"SERVICE_NAME:{GSSAPI_SERVICE_NAME}" + mech_properties += f",CANONICALIZE_HOST_NAME:{GSSAPI_CANONICALIZE}" + if GSSAPI_SERVICE_REALM is not None: + mech_properties += f",SERVICE_REALM:{GSSAPI_SERVICE_REALM}" + cls.mech_properties = mech_properties + + async def test_credentials_hashing(self): + # GSSAPI credentials are properly hashed. + creds0 = _build_credentials_tuple("GSSAPI", None, "user", "pass", {}, None) + + creds1 = _build_credentials_tuple( + "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "A"}}, None + ) + + creds2 = _build_credentials_tuple( + "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "A"}}, None + ) + + creds3 = _build_credentials_tuple( + "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "B"}}, None + ) + + self.assertEqual(1, len({creds1, creds2})) + self.assertEqual(3, len({creds0, creds1, creds2, creds3})) + + @ignore_deprecations + async def test_gssapi_simple(self): + assert GSSAPI_PRINCIPAL is not None + if GSSAPI_PASS is not None: + uri = "mongodb://%s:%s@%s:%d/?authMechanism=GSSAPI" % ( + quote_plus(GSSAPI_PRINCIPAL), + GSSAPI_PASS, + GSSAPI_HOST, + GSSAPI_PORT, + ) + else: + uri = "mongodb://%s@%s:%d/?authMechanism=GSSAPI" % ( + quote_plus(GSSAPI_PRINCIPAL), + GSSAPI_HOST, + GSSAPI_PORT, + ) + + if not self.service_realm_required: + # Without authMechanismProperties. + client = AsyncMongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + ) + + await client[GSSAPI_DB].collection.find_one() + + # Log in using URI, without authMechanismProperties. + client = AsyncMongoClient(uri) + await client[GSSAPI_DB].collection.find_one() + + # Authenticate with authMechanismProperties. + client = AsyncMongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + ) + + await client[GSSAPI_DB].collection.find_one() + + # Log in using URI, with authMechanismProperties. + mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" + client = AsyncMongoClient(mech_uri) + await client[GSSAPI_DB].collection.find_one() + + set_name = await client.admin.command(HelloCompat.LEGACY_CMD).get("setName") + if set_name: + if not self.service_realm_required: + # Without authMechanismProperties + client = AsyncMongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + replicaSet=set_name, + ) + + await client[GSSAPI_DB].list_collection_names() + + uri = uri + f"&replicaSet={set_name!s}" + client = AsyncMongoClient(uri) + await client[GSSAPI_DB].list_collection_names() + + # With authMechanismProperties + client = AsyncMongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + replicaSet=set_name, + ) + + await client[GSSAPI_DB].list_collection_names() + + mech_uri = mech_uri + f"&replicaSet={set_name!s}" + client = AsyncMongoClient(mech_uri) + await client[GSSAPI_DB].list_collection_names() + + @ignore_deprecations + @async_client_context.require_sync + async def test_gssapi_threaded(self): + client = AsyncMongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + ) + + # Authentication succeeded? + await client.server_info() + db = client[GSSAPI_DB] + + # Need one document in the collection. AutoAuthenticateThread does + # collection.find_one with a 1-second delay, forcing it to check out + # multiple connections from the pool concurrently, proving that + # auto-authentication works with GSSAPI. + collection = db.test + if not await collection.count_documents({}): + try: + await collection.drop() + await collection.insert_one({"_id": 1}) + except OperationFailure: + raise SkipTest("User must be able to write.") + + threads = [] + for _ in range(4): + threads.append(AutoAuthenticateThread(collection)) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertTrue(thread.success) + + set_name = await client.admin.command(HelloCompat.LEGACY_CMD).get("setName") + if set_name: + client = AsyncMongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + replicaSet=set_name, + ) + + # Succeeded? + await client.server_info() + + threads = [] + for _ in range(4): + threads.append(AutoAuthenticateThread(collection)) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertTrue(thread.success) + + +class TestSASLPlain(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls): + if not SASL_HOST or not SASL_USER or not SASL_PASS: + raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") + + async def test_sasl_plain(self): + client = AsyncMongoClient( + SASL_HOST, + SASL_PORT, + username=SASL_USER, + password=SASL_PASS, + authSource=SASL_DB, + authMechanism="PLAIN", + ) + await client.ldap.test.find_one() + + assert SASL_USER is not None + assert SASL_PASS is not None + uri = "mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;authSource=%s" % ( + quote_plus(SASL_USER), + quote_plus(SASL_PASS), + SASL_HOST, + SASL_PORT, + SASL_DB, + ) + client = AsyncMongoClient(uri) + await client.ldap.test.find_one() + + set_name = await client.admin.command(HelloCompat.LEGACY_CMD).get("setName") + if set_name: + client = AsyncMongoClient( + SASL_HOST, + SASL_PORT, + replicaSet=set_name, + username=SASL_USER, + password=SASL_PASS, + authSource=SASL_DB, + authMechanism="PLAIN", + ) + await client.ldap.test.find_one() + + uri = "mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;authSource=%s;replicaSet=%s" % ( + quote_plus(SASL_USER), + quote_plus(SASL_PASS), + SASL_HOST, + SASL_PORT, + SASL_DB, + str(set_name), + ) + client = AsyncMongoClient(uri) + await client.ldap.test.find_one() + + async def test_sasl_plain_bad_credentials(self): + def auth_string(user, password): + uri = "mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;authSource=%s" % ( + quote_plus(user), + quote_plus(password), + SASL_HOST, + SASL_PORT, + SASL_DB, + ) + return uri + + bad_user = AsyncMongoClient(auth_string("not-user", SASL_PASS)) + bad_pwd = AsyncMongoClient(auth_string(SASL_USER, "not-pwd")) + # OperationFailure raised upon connecting. + with self.assertRaises(OperationFailure): + await bad_user.admin.command("ping") + with self.assertRaises(OperationFailure): + await bad_pwd.admin.command("ping") + + +class TestSCRAMSHA1(AsyncIntegrationTest): + @async_client_context.require_auth + async def asyncSetUp(self): + await super().asyncSetUp() + await async_client_context.create_user( + "pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"] + ) + + async def asyncTearDown(self): + await async_client_context.drop_user("pymongo_test", "user") + await super().asyncTearDown() + + @async_client_context.require_no_fips + async def test_scram_sha1(self): + host, port = await async_client_context.host, await async_client_context.port + + client = await async_rs_or_single_client_noauth( + "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) + ) + await client.pymongo_test.command("dbstats") + + if async_client_context.is_rs: + uri = ( + "mongodb://user:pass" + "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" + "&replicaSet=%s" % (host, port, async_client_context.replica_set_name) + ) + client = await async_single_client_noauth(uri) + await client.pymongo_test.command("dbstats") + db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) + await db.command("dbstats") + + +# https://github.com/mongodb/specifications/blob/master/source/auth/auth.rst#scram-sha-256-and-mechanism-negotiation +class TestSCRAM(AsyncIntegrationTest): + @async_client_context.require_auth + @async_client_context.require_version_min(3, 7, 2) + async def asyncSetUp(self): + await super().asyncSetUp() + self._SENSITIVE_COMMANDS = monitoring._SENSITIVE_COMMANDS + monitoring._SENSITIVE_COMMANDS = set() + self.listener = AllowListEventListener("saslStart") + + async def asyncTearDown(self): + monitoring._SENSITIVE_COMMANDS = self._SENSITIVE_COMMANDS + await async_client_context.client.testscram.command("dropAllUsersFromDatabase") + await async_client_context.client.drop_database("testscram") + await super().asyncTearDown() + + async def test_scram_skip_empty_exchange(self): + listener = AllowListEventListener("saslStart", "saslContinue") + await async_client_context.create_user( + "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) + + client = await async_rs_or_single_client_noauth( + username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] + ) + await client.testscram.command("dbstats") + + if async_client_context.version < (4, 4, -1): + # Assert we sent the skipEmptyExchange option. + first_event = listener.started_events[0] + self.assertEqual(first_event.command_name, "saslStart") + self.assertEqual(first_event.command["options"], {"skipEmptyExchange": True}) + + # Assert the third exchange was skipped on servers that support it. + # Note that the first exchange occurs on the connection handshake. + started = listener.started_command_names() + if async_client_context.version.at_least(4, 4, -1): + self.assertEqual(started, ["saslContinue"]) + else: + self.assertEqual(started, ["saslStart", "saslContinue", "saslContinue"]) + + @async_client_context.require_no_fips + async def test_scram(self): + # Step 1: create users + await async_client_context.create_user( + "testscram", "sha1", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-1"] + ) + await async_client_context.create_user( + "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) + await async_client_context.create_user( + "testscram", + "both", + "pwd", + roles=["dbOwner"], + mechanisms=["SCRAM-SHA-1", "SCRAM-SHA-256"], + ) + + # Step 2: verify auth success cases + client = await async_rs_or_single_client_noauth( + username="sha1", password="pwd", authSource="testscram" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="sha256", password="pwd", authSource="testscram" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + await client.testscram.command("dbstats") + + # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 + client = await async_rs_or_single_client_noauth( + username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" + ) + await client.testscram.command("dbstats") + client = await async_rs_or_single_client_noauth( + username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + await client.testscram.command("dbstats") + + self.listener.reset() + client = await async_rs_or_single_client_noauth( + username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] + ) + await client.testscram.command("dbstats") + if async_client_context.version.at_least(4, 4, -1): + # Speculative authentication in 4.4+ sends saslStart with the + # handshake. + self.assertEqual(self.listener.started_events, []) + else: + started = self.listener.started_events[0] + self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") + + # Step 3: verify auth failure conditions + client = await async_rs_or_single_client_noauth( + username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + with self.assertRaises(OperationFailure): + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" + ) + with self.assertRaises(OperationFailure): + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="not-a-user", password="pwd", authSource="testscram" + ) + with self.assertRaises(OperationFailure): + await client.testscram.command("dbstats") + + if async_client_context.is_rs: + host, port = await async_client_context.host, await async_client_context.port + uri = "mongodb://both:pwd@%s:%d/testscram?replicaSet=%s" % ( + host, + port, + async_client_context.replica_set_name, + ) + client = await async_single_client_noauth(uri) + await client.testscram.command("dbstats") + db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) + await db.command("dbstats") + + @unittest.skipUnless(HAVE_STRINGPREP, "Cannot test without stringprep") + async def test_scram_saslprep(self): + # Step 4: test SASLprep + host, port = await async_client_context.host, await async_client_context.port + # Test the use of SASLprep on passwords. For example, + # saslprep('\u2136') becomes 'IV' and saslprep('I\u00ADX') + # becomes 'IX'. SASLprep is only supported when the standard + # library provides stringprep. + await async_client_context.create_user( + "testscram", "\u2168", "\u2163", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) + await async_client_context.create_user( + "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) + + client = await async_rs_or_single_client_noauth( + username="\u2168", password="\u2163", authSource="testscram" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="\u2168", + password="\u2163", + authSource="testscram", + authMechanism="SCRAM-SHA-256", + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="\u2168", password="IV", authSource="testscram" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="IX", password="I\u00ADX", authSource="testscram" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="IX", + password="I\u00ADX", + authSource="testscram", + authMechanism="SCRAM-SHA-256", + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) + ) + await client.testscram.command("dbstats") + client = await async_rs_or_single_client_noauth( + "mongodb://\u2168:IV@%s:%d/testscram" % (host, port) + ) + await client.testscram.command("dbstats") + + client = await async_rs_or_single_client_noauth( + "mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port) + ) + await client.testscram.command("dbstats") + client = await async_rs_or_single_client_noauth( + "mongodb://IX:IX@%s:%d/testscram" % (host, port) + ) + await client.testscram.command("dbstats") + + async def test_cache(self): + client = await async_single_client() + credentials = client.options.pool_options._credentials + cache = credentials.cache + self.assertIsNotNone(cache) + self.assertIsNone(cache.data) + # Force authentication. + await client.admin.command("ping") + cache = credentials.cache + self.assertIsNotNone(cache) + data = cache.data + self.assertIsNotNone(data) + self.assertEqual(len(data), 4) + ckey, skey, salt, iterations = data + self.assertIsInstance(ckey, bytes) + self.assertIsInstance(skey, bytes) + self.assertIsInstance(salt, bytes) + self.assertIsInstance(iterations, int) + + @async_client_context.require_sync + async def test_scram_threaded(self): + coll = async_client_context.client.db.test + await coll.drop() + await coll.insert_one({"_id": 1}) + + # The first thread to call find() will authenticate + client = await async_rs_or_single_client() + self.addAsyncCleanup(client.close) + coll = client.db.test + threads = [] + for _ in range(4): + threads.append(AutoAuthenticateThread(coll)) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertTrue(thread.success) + + +class TestAuthURIOptions(AsyncIntegrationTest): + @async_client_context.require_auth + async def asyncSetUp(self): + await super().asyncSetUp() + await async_client_context.create_user("admin", "admin", "pass") + await async_client_context.create_user( + "pymongo_test", "user", "pass", ["userAdmin", "readWrite"] + ) + + async def asyncTearDown(self): + await async_client_context.drop_user("pymongo_test", "user") + await async_client_context.drop_user("admin", "admin") + await super().asyncTearDown() + + async def test_uri_options(self): + # Test default to admin + host, port = await async_client_context.host, await async_client_context.port + client = await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + self.assertTrue(await client.admin.command("dbstats")) + + if async_client_context.is_rs: + uri = "mongodb://admin:pass@%s:%d/?replicaSet=%s" % ( + host, + port, + async_client_context.replica_set_name, + ) + client = await async_single_client_noauth(uri) + self.assertTrue(await client.admin.command("dbstats")) + db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) + self.assertTrue(await db.command("dbstats")) + + # Test explicit database + uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) + client = await async_rs_or_single_client_noauth(uri) + with self.assertRaises(OperationFailure): + await client.admin.command("dbstats") + self.assertTrue(await client.pymongo_test.command("dbstats")) + + if async_client_context.is_rs: + uri = "mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s" % ( + host, + port, + async_client_context.replica_set_name, + ) + client = await async_single_client_noauth(uri) + with self.assertRaises(OperationFailure): + await client.admin.command("dbstats") + self.assertTrue(await client.pymongo_test.command("dbstats")) + db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) + self.assertTrue(await db.command("dbstats")) + + # Test authSource + uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) + client = await async_rs_or_single_client_noauth(uri) + with self.assertRaises(OperationFailure): + await client.pymongo_test2.command("dbstats") + self.assertTrue(await client.pymongo_test.command("dbstats")) + + if async_client_context.is_rs: + uri = ( + "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" + "%s;authSource=pymongo_test" % (host, port, async_client_context.replica_set_name) + ) + client = await async_single_client_noauth(uri) + with self.assertRaises(OperationFailure): + await client.pymongo_test2.command("dbstats") + self.assertTrue(await client.pymongo_test.command("dbstats")) + db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) + self.assertTrue(await db.command("dbstats")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index eb431e1d5..030f468db 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -29,6 +29,7 @@ import traceback import uuid import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context +from test.asynchronous.test_bulk import AsyncBulkTestBase from threading import Thread from typing import Any, Dict, Mapping @@ -52,7 +53,6 @@ from test.helpers import ( KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, @@ -372,7 +372,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest): await target() -class TestEncryptedBulkWrite(BulkTestBase, AsyncEncryptionIntegrationTest): +class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest): async def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await async_rs_or_single_client(auto_encryption_opts=opts) diff --git a/test/test_auth.py b/test/test_auth.py index 2ae0eae12..0bf0cfd80 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -15,6 +15,7 @@ """Authentication Tests.""" from __future__ import annotations +import asyncio import os import sys import threading @@ -34,12 +35,14 @@ from test.utils import ( ) from pymongo import MongoClient, monitoring -from pymongo.asynchronous.auth import HAVE_KERBEROS from pymongo.auth_shared import _build_credentials_tuple from pymongo.errors import OperationFailure from pymongo.hello import HelloCompat from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP +from pymongo.synchronous.auth import HAVE_KERBEROS + +_IS_SYNC = True # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. GSSAPI_HOST = os.environ.get("GSSAPI_HOST") @@ -203,6 +206,7 @@ class TestGSSAPI(unittest.TestCase): client[GSSAPI_DB].list_collection_names() @ignore_deprecations + @client_context.require_sync def test_gssapi_threaded(self): client = MongoClient( GSSAPI_HOST, @@ -330,8 +334,10 @@ class TestSASLPlain(unittest.TestCase): bad_user = MongoClient(auth_string("not-user", SASL_PASS)) bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. - self.assertRaises(OperationFailure, bad_user.admin.command, "ping") - self.assertRaises(OperationFailure, bad_pwd.admin.command, "ping") + with self.assertRaises(OperationFailure): + bad_user.admin.command("ping") + with self.assertRaises(OperationFailure): + bad_pwd.admin.command("ping") class TestSCRAMSHA1(IntegrationTest): @@ -578,6 +584,7 @@ class TestSCRAM(IntegrationTest): self.assertIsInstance(salt, bytes) self.assertIsInstance(iterations, int) + @client_context.require_sync def test_scram_threaded(self): coll = client_context.client.db.test coll.drop() @@ -629,7 +636,8 @@ class TestAuthURIOptions(IntegrationTest): # Test explicit database uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) client = rs_or_single_client_noauth(uri) - self.assertRaises(OperationFailure, client.admin.command, "dbstats") + with self.assertRaises(OperationFailure): + client.admin.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) if client_context.is_rs: @@ -639,7 +647,8 @@ class TestAuthURIOptions(IntegrationTest): client_context.replica_set_name, ) client = single_client_noauth(uri) - self.assertRaises(OperationFailure, client.admin.command, "dbstats") + with self.assertRaises(OperationFailure): + client.admin.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command("dbstats")) @@ -647,7 +656,8 @@ class TestAuthURIOptions(IntegrationTest): # Test authSource uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) client = rs_or_single_client_noauth(uri) - self.assertRaises(OperationFailure, client.pymongo_test2.command, "dbstats") + with self.assertRaises(OperationFailure): + client.pymongo_test2.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) if client_context.is_rs: @@ -656,7 +666,8 @@ class TestAuthURIOptions(IntegrationTest): "%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name) ) client = single_client_noauth(uri) - self.assertRaises(OperationFailure, client.pymongo_test2.command, "dbstats") + with self.assertRaises(OperationFailure): + client.pymongo_test2.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command("dbstats")) diff --git a/test/test_encryption.py b/test/test_encryption.py index 568ebffc9..5e02e4d62 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -29,6 +29,7 @@ import traceback import uuid import warnings from test import IntegrationTest, PyMongoTestCase, client_context +from test.test_bulk import BulkTestBase from threading import Thread from typing import Any, Dict, Mapping @@ -52,7 +53,6 @@ from test.helpers import ( KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, diff --git a/test/unified_format.py b/test/unified_format.py index 168d35ee1..63cd23af8 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1170,9 +1170,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest): self.skipTest("Implement PYTHON-1894") if "timeoutMS applied to entire download" in spec["description"]: self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") - if "unpin after non-transient error on abort" in spec["description"]: - if client_context.version[0] == 8: - self.skipTest("Skipping TransientTransactionError pending PYTHON-4182") class_name = self.__class__.__name__.lower() description = spec["description"].lower() diff --git a/tools/synchro.py b/tools/synchro.py index e79cfce40..cde75b539 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -161,6 +161,7 @@ converted_tests = [ "pymongo_mocks.py", "utils_spec_runner.py", "qcheck.py", + "test_auth.py", "test_auth_spec.py", "test_bulk.py", "test_client.py",