PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait (#1875)

(cherry picked from commit 821811e80d)
This commit is contained in:
Shane Harvey 2024-09-30 16:24:07 -07:00 committed by Steven Silvester
parent d712bc1ec2
commit 6a7fae1cb7
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
14 changed files with 692 additions and 47 deletions

View File

@ -992,7 +992,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _ALock(_create_lock())
_lock = _create_lock()
self.lock = _ALock(_lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1018,7 +1019,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
self.size_cond = _ACondition(threading.Condition(_lock))
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1026,7 +1027,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id

View File

@ -170,8 +170,9 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
self._lock = _ALock(_create_lock())
self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type]
_lock = _create_lock()
self._lock = _ALock(_lock)
self._condition = _ACondition(self._settings.condition_class(_lock))
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None

View File

@ -14,17 +14,20 @@
from __future__ import annotations
import asyncio
import collections
import os
import threading
import time
import weakref
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TypeVar
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
# References to instances of _create_lock
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
_T = TypeVar("_T")
def _create_lock() -> threading.Lock:
"""Represents a lock that is tracked upon instantiation using a WeakSet and
@ -43,7 +46,14 @@ def _release_locks() -> None:
lock.release()
# Needed only for synchro.py compat.
def _Lock(lock: threading.Lock) -> threading.Lock:
    """Hand back the very same lock object that was passed in.

    This is an identity pass-through: no wrapping, copying, or locking
    takes place.
    """
    return lock
class _ALock:
__slots__ = ("_lock",)
def __init__(self, lock: threading.Lock) -> None:
self._lock = lock
@ -81,9 +91,18 @@ class _ALock:
self.release()
def _safe_set_result(fut: asyncio.Future) -> None:
    """Resolve ``fut`` with ``False`` unless it has already completed.

    A future that is already done (for example because it was cancelled)
    is left untouched, so ``set_result`` is never called on a finished
    future.
    """
    if fut.done():
        # Already resolved or cancelled: nothing to do.
        return
    fut.set_result(False)
class _ACondition:
__slots__ = ("_condition", "_waiters")
def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
self._waiters: collections.deque = collections.deque()
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
@ -99,30 +118,116 @@ class _ACondition:
await asyncio.sleep(0)
async def wait(self, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait(0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
"""Wait until notified.
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait_for(predicate, 0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
try:
try:
try:
await asyncio.wait_for(fut, timeout)
return True
except asyncio.TimeoutError:
return False # Return false on timeout for sync pool compat.
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except asyncio.exceptions.CancelledError as e:
err = e
self._waiters.remove((loop, fut))
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self.notify(1)
raise
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
self._condition.notify(n)
"""By default, wake up one coroutine waiting on this condition, if any.
If the calling coroutine has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up at most n of the coroutines waiting for the
condition variable; it is a no-op if no coroutines are waiting.
Note: an awakened coroutine does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
idx = 0
to_remove = []
for loop, fut in self._waiters:
if idx >= n:
break
if fut.done():
continue
try:
loop.call_soon_threadsafe(_safe_set_result, fut)
except RuntimeError:
# Loop was closed, ignore.
to_remove.append((loop, fut))
continue
idx += 1
for waiter in to_remove:
self._waiters.remove(waiter)
def notify_all(self) -> None:
self._condition.notify_all()
"""Wake up all coroutines waiting on this condition. This method acts
like notify(), but wakes up all waiting coroutines instead of one. If the
calling coroutine has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))
def locked(self) -> bool:
"""Only needed for tests in test_locks."""
return self._condition._lock.locked() # type: ignore[attr-defined]
def release(self) -> None:
self._condition.release()

View File

@ -62,7 +62,7 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
@ -988,7 +988,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _create_lock()
_lock = _create_lock()
self.lock = _Lock(_lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1014,7 +1015,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type]
self.size_cond = threading.Condition(_lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1022,7 +1023,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type]
self._max_connecting_cond = threading.Condition(_lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id

View File

@ -39,7 +39,7 @@ from pymongo.errors import (
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
@ -170,8 +170,9 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
self._lock = _create_lock()
self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type]
_lock = _create_lock()
self._lock = _Lock(_lock)
self._condition = self._settings.condition_class(_lock)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None

View File

@ -2351,7 +2351,9 @@ class TestMongoClientFailover(AsyncMockClientTest):
# But it can reconnect.
c.revive_host("a:1")
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
await (await c._get_topology()).select_servers(
writable_server_selector, _Op.TEST, server_selection_timeout=10
)
self.assertEqual(await c.address, ("a", 1))
async def _test_network_error(self, operation_callback):

View File

@ -30,6 +30,7 @@ from test.utils import (
)
from unittest.mock import patch
import pymongo
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
timeoutMS=2000,
w="majority",
)
await client.admin.command("ping") # Init the client first.
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(10):
await client.admin.command("ping")
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, NetworkTimeout)

View File

@ -1414,7 +1414,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])
@ -1456,7 +1456,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_command_cursor_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])

View File

@ -0,0 +1,513 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lock.py"""
from __future__ import annotations
import asyncio
import sys
import threading
import unittest
sys.path[0:0] = [""]
from pymongo.lock import _ACondition
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
await cond.acquire()
if await cond.wait():
result.append(1)
return True
async def c2(result):
await cond.acquire()
if await cond.wait():
result.append(2)
return True
async def c3(result):
await cond.acquire()
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertFalse(cond.locked())
self.assertTrue(await cond.acquire())
cond.notify()
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.notify(2)
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2, 3], result)
self.assertTrue(cond.locked())
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
wait = asyncio.create_task(cond.wait())
asyncio.get_running_loop().call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
self.assertFalse(cond._waiters)
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
self.assertTrue(cond.locked())
wait_task = asyncio.create_task(cond.wait())
await asyncio.sleep(0)
self.assertFalse(cond.locked())
# Notify, but contest the lock before cancelling
await cond.acquire()
self.assertTrue(cond.locked())
cond.notify()
asyncio.get_running_loop().call_soon(wait_task.cancel)
asyncio.get_running_loop().call_soon(cond.release)
try:
await wait_task
except asyncio.CancelledError:
# Should not happen, since no cancellation points
pass
self.assertTrue(cond.locked())
async def test_wait_cancel_after_notify(self):
# See bpo-32841
waited = False
cond = _ACondition(threading.Condition(threading.Lock()))
async def wait_on_cond():
nonlocal waited
async with cond:
waited = True # Make sure this area was reached
await cond.wait()
waiter = asyncio.create_task(wait_on_cond())
await asyncio.sleep(0) # Start waiting
await cond.acquire()
cond.notify()
await asyncio.sleep(0) # Get to acquire()
waiter.cancel()
await asyncio.sleep(0) # Activate cancellation
cond.release()
await asyncio.sleep(0) # Cancellation should occur
self.assertTrue(waiter.cancelled())
self.assertTrue(waited)
async def test_wait_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
with self.assertRaises(RuntimeError):
await cond.wait()
async def test_wait_for(self):
cond = _ACondition(threading.Condition(threading.Lock()))
presult = False
def predicate():
return presult
result = []
async def c1(result):
await cond.acquire()
if await cond.wait_for(predicate):
result.append(1)
cond.release()
return True
t = asyncio.create_task(c1(result))
await asyncio.sleep(0)
self.assertEqual([], result)
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([], result)
presult = True
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
self.assertEqual([1, 2, 3], res)
with self.assertRaises(RuntimeError):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
async def c3(result):
async with cond:
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify(1)
await asyncio.sleep(1)
self.assertEqual([1], result)
async with cond:
cond.notify(1)
cond.notify(2048)
await asyncio.sleep(1)
self.assertEqual([1, 2, 3], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify_all()
await asyncio.sleep(1)
self.assertEqual([1, 2], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _ACondition(threading.Condition(threading.Lock()))
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _ACondition(threading.Condition(threading.Lock()))
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-acquiring the lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# Cancel it while it is awaiting to be run.
# This cancellation could come from the outside
c[0].cancel()
# now wait for the item to be consumed
# if it isn't, that means our "notify" didn't take hold
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup_relock(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# now we sleep for a bit. This allows the target task to wake up and
# settle on re-acquiring the lock
await asyncio.sleep(0)
# Cancel it while awaiting the lock
# This cancel could come from the outside.
c[0].cancel()
# now wait for the item to be consumed
# if it isn't, that means our "notify" didn't take hold
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
class TestCondition(unittest.IsolatedAsyncioTestCase):
async def test_multiple_loops_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
def tmain(cond):
async def atmain(cond):
await asyncio.sleep(1)
async with cond:
cond.notify(1)
asyncio.run(atmain(cond))
t = threading.Thread(target=tmain, args=(cond,))
t.start()
async with cond:
self.assertTrue(await cond.wait(30))
t.join()
async def test_multiple_loops_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
results = []
def tmain(cond, results):
async def atmain(cond, results):
await asyncio.sleep(1)
async with cond:
res = await cond.wait(30)
results.append(res)
asyncio.run(atmain(cond, results))
nthreads = 5
threads = []
for _ in range(nthreads):
threads.append(threading.Thread(target=tmain, args=(cond, results)))
for t in threads:
t.start()
await asyncio.sleep(2)
async with cond:
cond.notify_all()
for t in threads:
t.join()
self.assertEqual(results, [True] * nthreads)
if __name__ == "__main__":
unittest.main()

View File

@ -2307,7 +2307,9 @@ class TestMongoClientFailover(MockClientTest):
# But it can reconnect.
c.revive_host("a:1")
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
(c._get_topology()).select_servers(
writable_server_selector, _Op.TEST, server_selection_timeout=10
)
self.assertEqual(c.address, ("a", 1))
def _test_network_error(self, operation_callback):

View File

@ -30,6 +30,7 @@ from test.utils import (
)
from unittest.mock import patch
import pymongo
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
ClientBulkWriteException,
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(IntegrationTest):
timeoutMS=2000,
w="majority",
)
client.admin.command("ping") # Init the client first.
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(10):
client.admin.command("ping")
with self.assertRaises(ClientBulkWriteException) as context:
client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, NetworkTimeout)

View File

@ -1405,7 +1405,7 @@ class TestCursor(IntegrationTest):
def test_to_list_csot_applied(self):
client = self.single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
client.admin.command("ping")
coll = client.pymongo.test
coll.insert_many([{} for _ in range(5)])
@ -1447,7 +1447,7 @@ class TestCursor(IntegrationTest):
def test_command_cursor_to_list_csot_applied(self):
client = self.single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
client.admin.command("ping")
coll = client.pymongo.test
coll.insert_many([{} for _ in range(5)])

View File

@ -19,6 +19,7 @@ import os
import threading
from test import IntegrationTest, client_context, unittest
from test.utils import (
CMAPListener,
OvertCommandListener,
SpecTestCreator,
get_pool,
@ -27,6 +28,7 @@ from test.utils import (
from test.utils_selection_tests import create_topology
from pymongo.common import clean_node
from pymongo.monitoring import ConnectionReadyEvent
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference
@ -131,19 +133,20 @@ class TestProse(IntegrationTest):
@client_context.require_multiple_mongoses
def test_load_balancing(self):
listener = OvertCommandListener()
cmap_listener = CMAPListener()
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
# varying RTTs.
client = self.rs_client(
client_context.mongos_seeds(),
appName="loadBalancingTest",
event_listeners=[listener],
event_listeners=[listener, cmap_listener],
localThresholdMS=30000,
minPoolSize=10,
)
self.addCleanup(client.close)
wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections")
# Delay find commands on
# Wait for both pools to be populated.
cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
# Delay find commands on only one mongos.
delay_finds = {
"configureFailPoint": "failCommand",
"mode": {"times": 10000},
@ -161,7 +164,7 @@ class TestProse(IntegrationTest):
freqs = self.frequencies(client, listener)
self.assertLessEqual(freqs[delayed_server], 0.25)
listener.reset()
freqs = self.frequencies(client, listener, n_finds=100)
freqs = self.frequencies(client, listener, n_finds=150)
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)

View File

@ -144,7 +144,17 @@ gridfs_files = [
_gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file()
]
test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()]
def async_only_test(f: str) -> bool:
    """Tell whether ``f`` names a test file that exists only in async form.

    Files for which this returns True should not be converted to sync.
    """
    async_only_names = ("test_locks.py",)
    return any(f == name for name in async_only_names)
test_files = [
_test_base + f
for f in listdir(_test_base)
if (Path(_test_base) / f).is_file() and not async_only_test(f)
]
sync_files = [
_pymongo_dest_base + f
@ -240,7 +250,7 @@ def translate_locks(lines: list[str]) -> list[str]:
lock_lines = [line for line in lines if "_Lock(" in line]
cond_lines = [line for line in lines if "_Condition(" in line]
for line in lock_lines:
res = re.search(r"_Lock\(([^()]*\(\))\)", line)
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
if res:
old = res[0]
index = lines.index(line)