Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-02-06 09:27:48 -06:00
commit 464bcdeff5
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
9 changed files with 331 additions and 80 deletions

View File

@ -40,7 +40,7 @@ try:
except ImportError:
HAVE_IPADDRESS = False
from functools import wraps
from typing import Any, Callable, Dict, Generator, no_type_check
from typing import Any, Callable, Dict, Generator, Optional, no_type_check
from unittest import SkipTest
from bson.son import SON
@ -395,7 +395,7 @@ class ConcurrentRunner(PARENT):
async def start(self):
self.task = create_task(self.run(), name=self.name)
async def join(self, timeout: float | None = 0): # type: ignore[override]
async def join(self, timeout: Optional[float] = None): # type: ignore[override]
if self.task is not None:
await asyncio.wait([self.task], timeout=timeout)
@ -407,3 +407,18 @@ class ConcurrentRunner(PARENT):
await self.target(*self.args)
finally:
self.stopped = True
class ExceptionCatchingTask(ConcurrentRunner):
"""A Task that stores any exception encountered while running."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.exc = None
async def run(self):
try:
await super().run()
except BaseException as exc:
self.exc = exc
raise

View File

@ -0,0 +1,199 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Load Balancer unified spec tests."""
from __future__ import annotations
import asyncio
import gc
import os
import pathlib
import sys
import threading
from asyncio import Event
from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask
import pytest
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.unified_format import generate_test_classes
from test.utils import (
async_get_pool,
async_wait_until,
create_async_event,
)
from pymongo.asynchronous.helpers import anext
_IS_SYNC = False
pytestmark = pytest.mark.load_balancer
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer")
# Generate unified tests.
globals().update(generate_test_classes(_TEST_PATH, module=__name__))
class TestLB(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
async def test_connections_are_only_returned_once(self):
if "PyPy" in sys.version:
# Tracked in PYTHON-3011
self.skipTest("Test is flaky on PyPy")
pool = await async_get_pool(self.client)
n_conns = len(pool.conns)
await self.db.test.find_one({})
self.assertEqual(len(pool.conns), n_conns)
await (await self.db.test.aggregate([{"$limit": 1}])).to_list()
self.assertEqual(len(pool.conns), n_conns)
@async_client_context.require_load_balancer
async def test_unpin_committed_transaction(self):
client = await self.async_rs_client()
pool = await async_get_pool(client)
coll = client[self.db.name].test
async with client.start_session() as session:
async with await session.start_transaction():
self.assertEqual(pool.active_sockets, 0)
await coll.insert_one({}, session=session)
self.assertEqual(pool.active_sockets, 1) # Pinned.
self.assertEqual(pool.active_sockets, 1) # Still pinned.
self.assertEqual(pool.active_sockets, 0) # Unpinned.
@async_client_context.require_failCommand_fail_point
async def test_cursor_gc(self):
async def create_resource(coll):
cursor = coll.find({}, batch_size=3)
await anext(cursor)
return cursor
await self._test_no_gc_deadlock(create_resource)
@async_client_context.require_failCommand_fail_point
async def test_command_cursor_gc(self):
async def create_resource(coll):
cursor = await coll.aggregate([], batchSize=3)
await anext(cursor)
return cursor
await self._test_no_gc_deadlock(create_resource)
async def _test_no_gc_deadlock(self, create_resource):
client = await self.async_rs_client()
pool = await async_get_pool(client)
coll = client[self.db.name].test
await coll.insert_many([{} for _ in range(10)])
self.assertEqual(pool.active_sockets, 0)
# Cause the initial find attempt to fail to induce a reference cycle.
args = {
"mode": {"times": 1},
"data": {
"failCommands": ["find", "aggregate"],
"closeConnection": True,
},
}
async with self.fail_point(args):
resource = await create_resource(coll)
if async_client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.
task = PoolLocker(pool)
await task.start()
self.assertTrue(await task.wait(task.locked, 5), "timed out")
# Garbage collect the resource while the pool is locked to ensure we
# don't deadlock.
del resource
# On PyPy it can take a few rounds to collect the cursor.
for _ in range(3):
gc.collect()
task.unlock.set()
await task.join(5)
self.assertFalse(task.is_alive())
self.assertIsNone(task.exc)
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# Run another operation to ensure the socket still works.
await coll.delete_many({})
@async_client_context.require_transactions
async def test_session_gc(self):
client = await self.async_rs_client()
pool = await async_get_pool(client)
session = client.start_session()
await session.start_transaction()
await client.test_session_gc.test.find_one({}, session=session)
# Cleanup the transaction left open on the server unless we're
# testing serverless which does not support killSessions.
if not async_client_context.serverless:
self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id])
if async_client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.
task = PoolLocker(pool)
await task.start()
self.assertTrue(await task.wait(task.locked, 5), "timed out")
# Garbage collect the session while the pool is locked to ensure we
# don't deadlock.
del session
# On PyPy it can take a few rounds to collect the session.
for _ in range(3):
gc.collect()
task.unlock.set()
await task.join(5)
self.assertFalse(task.is_alive())
self.assertIsNone(task.exc)
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# Run another operation to ensure the socket still works.
await client[self.db.name].test.delete_many({})
class PoolLocker(ExceptionCatchingTask):
def __init__(self, pool):
super().__init__(target=self.lock_pool)
self.pool = pool
self.daemon = True
self.locked = create_async_event()
self.unlock = create_async_event()
async def lock_pool(self):
async with self.pool.lock:
self.locked.set()
# Wait for the unlock flag.
unlock_pool = await self.wait(self.unlock, 10)
if not unlock_pool:
raise Exception("timed out waiting for unlock signal: deadlock?")
async def wait(self, event: Event, timeout: int):
if _IS_SYNC:
return event.wait(timeout) # type: ignore[call-arg]
else:
try:
await asyncio.wait_for(event.wait(), timeout=timeout)
except asyncio.TimeoutError:
return False
return True
if __name__ == "__main__":
unittest.main()

View File

@ -15,10 +15,13 @@
"""Test the client_session module."""
from __future__ import annotations
import asyncio
import copy
import sys
import time
from asyncio import iscoroutinefunction
from io import BytesIO
from test.asynchronous.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple
from pymongo.synchronous.mongo_client import MongoClient
@ -35,7 +38,6 @@ from test.asynchronous import (
)
from test.utils import (
EventListener,
ExceptionCatchingThread,
OvertCommandListener,
async_wait_until,
)
@ -184,8 +186,7 @@ class TestSession(AsyncIntegrationTest):
f"{f.__name__} did not return implicit session to pool",
)
@async_client_context.require_sync
def test_implicit_sessions_checkout(self):
async def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
succeeded = False
@ -193,7 +194,7 @@ class TestSession(AsyncIntegrationTest):
failures = 0
for _ in range(5):
listener = OvertCommandListener()
client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
cursor = client.db.test.find({})
ops: List[Tuple[Callable, List[Any]]] = [
(client.db.test.find_one, [{"_id": 1}]),
@ -210,26 +211,27 @@ class TestSession(AsyncIntegrationTest):
(cursor.distinct, ["_id"]),
(client.db.list_collections, []),
]
threads = []
tasks = []
listener.reset()
def thread_target(op, *args):
res = op(*args)
async def target(op, *args):
if iscoroutinefunction(op):
res = await op(*args)
else:
res = op(*args)
if isinstance(res, (AsyncCursor, AsyncCommandCursor)):
list(res) # type: ignore[call-overload]
await res.to_list()
for op, args in ops:
threads.append(
ExceptionCatchingThread(
target=thread_target, args=[op, *args], name=op.__name__
)
tasks.append(
ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__)
)
threads[-1].start()
self.assertEqual(len(threads), len(ops))
for thread in threads:
thread.join()
self.assertIsNone(thread.exc)
client.close()
await tasks[-1].start()
self.assertEqual(len(tasks), len(ops))
for t in tasks:
await t.join()
self.assertIsNone(t.exc)
await client.close()
lsid_set.clear()
for i in listener.started_events:
if i.command.get("lsid"):

View File

@ -40,7 +40,7 @@ try:
except ImportError:
HAVE_IPADDRESS = False
from functools import wraps
from typing import Any, Callable, Dict, Generator, no_type_check
from typing import Any, Callable, Dict, Generator, Optional, no_type_check
from unittest import SkipTest
from bson.son import SON
@ -395,7 +395,7 @@ class ConcurrentRunner(PARENT):
def start(self):
self.task = create_task(self.run(), name=self.name)
def join(self, timeout: float | None = 0): # type: ignore[override]
def join(self, timeout: Optional[float] = None): # type: ignore[override]
if self.task is not None:
asyncio.wait([self.task], timeout=timeout)
@ -407,3 +407,18 @@ class ConcurrentRunner(PARENT):
self.target(*self.args)
finally:
self.stopped = True
class ExceptionCatchingTask(ConcurrentRunner):
"""A Task that stores any exception encountered while running."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.exc = None
def run(self):
try:
super().run()
except BaseException as exc:
self.exc = exc
raise

View File

@ -33,7 +33,7 @@ from io import BytesIO
sys.path[0:0] = [""]
from test import qcheck, unittest
from test.utils import ExceptionCatchingThread
from test.helpers import ExceptionCatchingTask
import bson
from bson import (
@ -1075,7 +1075,7 @@ class TestBSON(unittest.TestCase):
my_int = type(f"MyInt_{i}_{j}", (int,), {})
bson.encode({"my_int": my_int()})
threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)]
threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)]
for t in threads:
t.start()
@ -1114,7 +1114,7 @@ class TestBSON(unittest.TestCase):
def test_doc_in_invalid_document_error_message_mapping(self):
class MyMapping(abc.Mapping):
def keys():
def keys(self):
return ["t"]
def __getitem__(self, name):

View File

@ -15,10 +15,14 @@
"""Test the Load Balancer unified spec tests."""
from __future__ import annotations
import asyncio
import gc
import os
import pathlib
import sys
import threading
from asyncio import Event
from test.helpers import ConcurrentRunner, ExceptionCatchingTask
import pytest
@ -26,15 +30,26 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.unified_format import generate_test_classes
from test.utils import ExceptionCatchingThread, get_pool, wait_until
from test.utils import (
create_event,
get_pool,
wait_until,
)
from pymongo.synchronous.helpers import next
_IS_SYNC = True
pytestmark = pytest.mark.load_balancer
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "load_balancer")
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
globals().update(generate_test_classes(_TEST_PATH, module=__name__))
class TestLB(IntegrationTest):
@ -49,13 +64,12 @@ class TestLB(IntegrationTest):
n_conns = len(pool.conns)
self.db.test.find_one({})
self.assertEqual(len(pool.conns), n_conns)
list(self.db.test.aggregate([{"$limit": 1}]))
(self.db.test.aggregate([{"$limit": 1}])).to_list()
self.assertEqual(len(pool.conns), n_conns)
@client_context.require_load_balancer
def test_unpin_committed_transaction(self):
client = self.rs_client()
self.addCleanup(client.close)
pool = get_pool(client)
coll = client[self.db.name].test
with client.start_session() as session:
@ -86,7 +100,6 @@ class TestLB(IntegrationTest):
def _test_no_gc_deadlock(self, create_resource):
client = self.rs_client()
self.addCleanup(client.close)
pool = get_pool(client)
coll = client[self.db.name].test
coll.insert_many([{} for _ in range(10)])
@ -104,19 +117,19 @@ class TestLB(IntegrationTest):
if client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.
thread = PoolLocker(pool)
thread.start()
self.assertTrue(thread.locked.wait(5), "timed out")
task = PoolLocker(pool)
task.start()
self.assertTrue(task.wait(task.locked, 5), "timed out")
# Garbage collect the resource while the pool is locked to ensure we
# don't deadlock.
del resource
# On PyPy it can take a few rounds to collect the cursor.
for _ in range(3):
gc.collect()
thread.unlock.set()
thread.join(5)
self.assertFalse(thread.is_alive())
self.assertIsNone(thread.exc)
task.unlock.set()
task.join(5)
self.assertFalse(task.is_alive())
self.assertIsNone(task.exc)
wait_until(lambda: pool.active_sockets == 0, "return socket")
# Run another operation to ensure the socket still works.
@ -125,7 +138,6 @@ class TestLB(IntegrationTest):
@client_context.require_transactions
def test_session_gc(self):
client = self.rs_client()
self.addCleanup(client.close)
pool = get_pool(client)
session = client.start_session()
session.start_transaction()
@ -137,41 +149,51 @@ class TestLB(IntegrationTest):
if client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.
thread = PoolLocker(pool)
thread.start()
self.assertTrue(thread.locked.wait(5), "timed out")
task = PoolLocker(pool)
task.start()
self.assertTrue(task.wait(task.locked, 5), "timed out")
# Garbage collect the session while the pool is locked to ensure we
# don't deadlock.
del session
# On PyPy it can take a few rounds to collect the session.
for _ in range(3):
gc.collect()
thread.unlock.set()
thread.join(5)
self.assertFalse(thread.is_alive())
self.assertIsNone(thread.exc)
task.unlock.set()
task.join(5)
self.assertFalse(task.is_alive())
self.assertIsNone(task.exc)
wait_until(lambda: pool.active_sockets == 0, "return socket")
# Run another operation to ensure the socket still works.
client[self.db.name].test.delete_many({})
class PoolLocker(ExceptionCatchingThread):
class PoolLocker(ExceptionCatchingTask):
def __init__(self, pool):
super().__init__(target=self.lock_pool)
self.pool = pool
self.daemon = True
self.locked = threading.Event()
self.unlock = threading.Event()
self.locked = create_event()
self.unlock = create_event()
def lock_pool(self):
with self.pool.lock:
self.locked.set()
# Wait for the unlock flag.
unlock_pool = self.unlock.wait(10)
unlock_pool = self.wait(self.unlock, 10)
if not unlock_pool:
raise Exception("timed out waiting for unlock signal: deadlock?")
def wait(self, event: Event, timeout: int):
if _IS_SYNC:
return event.wait(timeout) # type: ignore[call-arg]
else:
try:
asyncio.wait_for(event.wait(), timeout=timeout)
except asyncio.TimeoutError:
return False
return True
if __name__ == "__main__":
unittest.main()

View File

@ -15,10 +15,13 @@
"""Test the client_session module."""
from __future__ import annotations
import asyncio
import copy
import sys
import time
from asyncio import iscoroutinefunction
from io import BytesIO
from test.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple
from pymongo.synchronous.mongo_client import MongoClient
@ -35,7 +38,6 @@ from test import (
)
from test.utils import (
EventListener,
ExceptionCatchingThread,
OvertCommandListener,
wait_until,
)
@ -184,7 +186,6 @@ class TestSession(IntegrationTest):
f"{f.__name__} did not return implicit session to pool",
)
@client_context.require_sync
def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
@ -210,25 +211,26 @@ class TestSession(IntegrationTest):
(cursor.distinct, ["_id"]),
(client.db.list_collections, []),
]
threads = []
tasks = []
listener.reset()
def thread_target(op, *args):
res = op(*args)
def target(op, *args):
if iscoroutinefunction(op):
res = op(*args)
else:
res = op(*args)
if isinstance(res, (Cursor, CommandCursor)):
list(res) # type: ignore[call-overload]
res.to_list()
for op, args in ops:
threads.append(
ExceptionCatchingThread(
target=thread_target, args=[op, *args], name=op.__name__
)
tasks.append(
ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__)
)
threads[-1].start()
self.assertEqual(len(threads), len(ops))
for thread in threads:
thread.join()
self.assertIsNone(thread.exc)
tasks[-1].start()
self.assertEqual(len(tasks), len(ops))
for t in tasks:
t.join()
self.assertIsNone(t.exc)
client.close()
lsid_set.clear()
for i in listener.started_events:

View File

@ -39,6 +39,7 @@ from bson import json_util
from bson.objectid import ObjectId
from bson.son import SON
from pymongo import AsyncMongoClient, monitoring, operations, read_preferences
from pymongo._asyncio_task import create_task
from pymongo.cursor_shared import CursorType
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.hello import HelloCompat
@ -912,21 +913,6 @@ def is_greenthread_patched():
return gevent_monkey_patched() or eventlet_monkey_patched()
class ExceptionCatchingThread(threading.Thread):
"""A thread that stores any exception encountered from run()."""
def __init__(self, *args, **kwargs):
self.exc = None
super().__init__(*args, **kwargs)
def run(self):
try:
super().run()
except BaseException as exc:
self.exc = exc
raise
def parse_read_preference(pref):
# Make first letter lowercase to match read_pref's modes.
mode_string = pref.get("mode", "primary")
@ -1079,3 +1065,11 @@ async def async_set_fail_point(client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
await client.admin.command(cmd)
def create_async_event():
return asyncio.Event()
def create_event():
return threading.Event()

View File

@ -122,6 +122,7 @@ replacements = {
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
"create_async_event": "create_event",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -214,6 +215,7 @@ converted_tests = [
"test_heartbeat_monitoring.py",
"test_index_management.py",
"test_grid_file.py",
"test_load_balancer.py",
"test_json_util_integration.py",
"test_gridfs_spec.py",
"test_logger.py",