PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait (#1875)
This commit is contained in:
parent
083359f95f
commit
821811e80d
@ -992,7 +992,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
self.lock = _ALock(_create_lock())
|
||||
_lock = _create_lock()
|
||||
self.lock = _ALock(_lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1018,7 +1019,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
|
||||
self.size_cond = _ACondition(threading.Condition(_lock))
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1026,7 +1027,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
|
||||
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
|
||||
@ -170,8 +170,9 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._lock = _ALock(_create_lock())
|
||||
self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type]
|
||||
_lock = _create_lock()
|
||||
self._lock = _ALock(_lock)
|
||||
self._condition = _ACondition(self._settings.condition_class(_lock))
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
|
||||
145
pymongo/lock.py
145
pymongo/lock.py
@ -14,17 +14,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||
|
||||
# References to instances of _create_lock
|
||||
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _create_lock() -> threading.Lock:
|
||||
"""Represents a lock that is tracked upon instantiation using a WeakSet and
|
||||
@ -43,7 +46,14 @@ def _release_locks() -> None:
|
||||
lock.release()
|
||||
|
||||
|
||||
# Needed only for synchro.py compat.
|
||||
def _Lock(lock: threading.Lock) -> threading.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
class _ALock:
|
||||
__slots__ = ("_lock",)
|
||||
|
||||
def __init__(self, lock: threading.Lock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
@ -81,9 +91,18 @@ class _ALock:
|
||||
self.release()
|
||||
|
||||
|
||||
def _safe_set_result(fut: asyncio.Future) -> None:
|
||||
# Ensure the future hasn't been cancelled before calling set_result.
|
||||
if not fut.done():
|
||||
fut.set_result(False)
|
||||
|
||||
|
||||
class _ACondition:
|
||||
__slots__ = ("_condition", "_waiters")
|
||||
|
||||
def __init__(self, condition: threading.Condition) -> None:
|
||||
self._condition = condition
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
|
||||
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
if timeout > 0:
|
||||
@ -99,30 +118,116 @@ class _ACondition:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
if timeout is not None:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
notified = self._condition.wait(0.001)
|
||||
if notified:
|
||||
return True
|
||||
if timeout is not None and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
"""Wait until notified.
|
||||
|
||||
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
|
||||
if timeout is not None:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
notified = self._condition.wait_for(predicate, 0.001)
|
||||
if notified:
|
||||
return True
|
||||
if timeout is not None and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
If the calling task has not acquired the lock when this
|
||||
method is called, a RuntimeError is raised.
|
||||
|
||||
This method releases the underlying lock, and then blocks
|
||||
until it is awakened by a notify() or notify_all() call for
|
||||
the same condition variable in another task. Once
|
||||
awakened, it re-acquires the lock and returns True.
|
||||
|
||||
This method may return spuriously,
|
||||
which is why the caller should always
|
||||
re-check the state and be prepared to wait() again.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
fut = loop.create_future()
|
||||
self._waiters.append((loop, fut))
|
||||
self.release()
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False # Return false on timeout for sync pool compat.
|
||||
finally:
|
||||
# Must re-acquire lock even if wait is cancelled.
|
||||
# We only catch CancelledError here, since we don't want any
|
||||
# other (fatal) errors with the future to cause us to spin.
|
||||
err = None
|
||||
while True:
|
||||
try:
|
||||
await self.acquire()
|
||||
break
|
||||
except asyncio.exceptions.CancelledError as e:
|
||||
err = e
|
||||
|
||||
self._waiters.remove((loop, fut))
|
||||
if err is not None:
|
||||
try:
|
||||
raise err # Re-raise most recent exception instance.
|
||||
finally:
|
||||
err = None # Break reference cycles.
|
||||
except BaseException:
|
||||
# Any error raised out of here _may_ have occurred after this Task
|
||||
# believed to have been successfully notified.
|
||||
# Make sure to notify another Task instead. This may result
|
||||
# in a "spurious wakeup", which is allowed as part of the
|
||||
# Condition Variable protocol.
|
||||
self.notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
interpreted as a boolean value. The method will repeatedly
|
||||
wait() until it evaluates to true. The final predicate value is
|
||||
the return value.
|
||||
"""
|
||||
result = predicate()
|
||||
while not result:
|
||||
await self.wait()
|
||||
result = predicate()
|
||||
return result
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
self._condition.notify(n)
|
||||
"""By default, wake up one coroutine waiting on this condition, if any.
|
||||
If the calling coroutine has not acquired the lock when this method
|
||||
is called, a RuntimeError is raised.
|
||||
|
||||
This method wakes up at most n of the coroutines waiting for the
|
||||
condition variable; it is a no-op if no coroutines are waiting.
|
||||
|
||||
Note: an awakened coroutine does not actually return from its
|
||||
wait() call until it can reacquire the lock. Since notify() does
|
||||
not release the lock, its caller should.
|
||||
"""
|
||||
idx = 0
|
||||
to_remove = []
|
||||
for loop, fut in self._waiters:
|
||||
if idx >= n:
|
||||
break
|
||||
|
||||
if fut.done():
|
||||
continue
|
||||
|
||||
try:
|
||||
loop.call_soon_threadsafe(_safe_set_result, fut)
|
||||
except RuntimeError:
|
||||
# Loop was closed, ignore.
|
||||
to_remove.append((loop, fut))
|
||||
continue
|
||||
|
||||
idx += 1
|
||||
|
||||
for waiter in to_remove:
|
||||
self._waiters.remove(waiter)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
self._condition.notify_all()
|
||||
"""Wake up all threads waiting on this condition. This method acts
|
||||
like notify(), but wakes up all waiting threads instead of one. If the
|
||||
calling thread has not acquired the lock when this method is called,
|
||||
a RuntimeError is raised.
|
||||
"""
|
||||
self.notify(len(self._waiters))
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Only needed for tests in test_locks."""
|
||||
return self._condition._lock.locked() # type: ignore[attr-defined]
|
||||
|
||||
def release(self) -> None:
|
||||
self._condition.release()
|
||||
|
||||
@ -62,7 +62,7 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
@ -988,7 +988,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
self.lock = _create_lock()
|
||||
_lock = _create_lock()
|
||||
self.lock = _Lock(_lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1014,7 +1015,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type]
|
||||
self.size_cond = threading.Condition(_lock)
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1022,7 +1023,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type]
|
||||
self._max_connecting_cond = threading.Condition(_lock)
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
|
||||
@ -39,7 +39,7 @@ from pymongo.errors import (
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.logger import (
|
||||
_SDAM_LOGGER,
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
@ -170,8 +170,9 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._lock = _create_lock()
|
||||
self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type]
|
||||
_lock = _create_lock()
|
||||
self._lock = _Lock(_lock)
|
||||
self._condition = self._settings.condition_class(_lock)
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
|
||||
@ -2433,7 +2433,9 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
|
||||
# But it can reconnect.
|
||||
c.revive_host("a:1")
|
||||
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
await (await c._get_topology()).select_servers(
|
||||
writable_server_selector, _Op.TEST, server_selection_timeout=10
|
||||
)
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
|
||||
async def _test_network_error(self, operation_callback):
|
||||
|
||||
@ -30,6 +30,7 @@ from test.utils import (
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import pymongo
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
|
||||
from pymongo.errors import (
|
||||
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
await client.admin.command("ping") # Init the client first.
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(10):
|
||||
await client.admin.command("ping")
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
await client.bulk_write(models=models)
|
||||
self.assertIsInstance(context.exception.error, NetworkTimeout)
|
||||
|
||||
@ -1414,7 +1414,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_to_list_csot_applied(self):
|
||||
client = await self.async_single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
await client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
@ -1456,7 +1456,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_command_cursor_to_list_csot_applied(self):
|
||||
client = await self.async_single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
await client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
|
||||
513
test/asynchronous/test_locks.py
Normal file
513
test/asynchronous/test_locks.py
Normal file
@ -0,0 +1,513 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for lock.py"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.lock import _ACondition
|
||||
|
||||
|
||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||
# Includes tests for:
|
||||
# - https://github.com/python/cpython/issues/111693
|
||||
# - https://github.com/python/cpython/issues/112202
|
||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_wait(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
async def c3(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
self.assertTrue(await cond.acquire())
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.notify(2)
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_wait_cancel(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
await cond.acquire()
|
||||
|
||||
wait = asyncio.create_task(cond.wait())
|
||||
asyncio.get_running_loop().call_soon(wait.cancel)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await wait
|
||||
self.assertFalse(cond._waiters)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_contested(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
wait_task = asyncio.create_task(cond.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
# Notify, but contest the lock before cancelling
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
cond.notify()
|
||||
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
||||
asyncio.get_running_loop().call_soon(cond.release)
|
||||
|
||||
try:
|
||||
await wait_task
|
||||
except asyncio.CancelledError:
|
||||
# Should not happen, since no cancellation points
|
||||
pass
|
||||
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_after_notify(self):
|
||||
# See bpo-32841
|
||||
waited = False
|
||||
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def wait_on_cond():
|
||||
nonlocal waited
|
||||
async with cond:
|
||||
waited = True # Make sure this area was reached
|
||||
await cond.wait()
|
||||
|
||||
waiter = asyncio.create_task(wait_on_cond())
|
||||
await asyncio.sleep(0) # Start waiting
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
await asyncio.sleep(0) # Get to acquire()
|
||||
waiter.cancel()
|
||||
await asyncio.sleep(0) # Activate cancellation
|
||||
cond.release()
|
||||
await asyncio.sleep(0) # Cancellation should occur
|
||||
|
||||
self.assertTrue(waiter.cancelled())
|
||||
self.assertTrue(waited)
|
||||
|
||||
async def test_wait_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait()
|
||||
|
||||
async def test_wait_for(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
presult = False
|
||||
|
||||
def predicate():
|
||||
return presult
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait_for(predicate):
|
||||
result.append(1)
|
||||
cond.release()
|
||||
return True
|
||||
|
||||
t = asyncio.create_task(c1(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
presult = True
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
self.assertTrue(t.done())
|
||||
self.assertTrue(t.result())
|
||||
|
||||
async def test_wait_for_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
# predicate can return true immediately
|
||||
res = await cond.wait_for(lambda: [1, 2, 3])
|
||||
self.assertEqual([1, 2, 3], res)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait_for(lambda: False)
|
||||
|
||||
async def test_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
async def c3(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
cond.notify(2048)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
|
||||
async def test_context_manager(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
self.assertFalse(cond.locked())
|
||||
async with cond:
|
||||
self.assertTrue(cond.locked())
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
async def test_timeout_in_block(self):
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
async with condition:
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_wakeup(self):
|
||||
# Test that a cancelled error, received when awaiting wakeup,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_re_aquire(self):
|
||||
# Test that a cancelled error, received when re-aquiring lock,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition
|
||||
await cond.acquire()
|
||||
wake = True
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
# Task is now trying to re-acquire the lock, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
cond.release()
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is awaiting initial
|
||||
# wakeup on the wakeup queue.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# Cancel it while it is awaiting to be run.
|
||||
# This cancellation could come from the outside
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup_relock(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is acquiring the lock
|
||||
# again.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# now we sleep for a bit. This allows the target task to wake up and
|
||||
# settle on re-aquiring the lock
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Cancel it while awaiting the lock
|
||||
# This cancel could come the outside.
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
|
||||
class TestCondition(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_multiple_loops_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
def tmain(cond):
|
||||
async def atmain(cond):
|
||||
await asyncio.sleep(1)
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
|
||||
asyncio.run(atmain(cond))
|
||||
|
||||
t = threading.Thread(target=tmain, args=(cond,))
|
||||
t.start()
|
||||
|
||||
async with cond:
|
||||
self.assertTrue(await cond.wait(30))
|
||||
t.join()
|
||||
|
||||
async def test_multiple_loops_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
results = []
|
||||
|
||||
def tmain(cond, results):
|
||||
async def atmain(cond, results):
|
||||
await asyncio.sleep(1)
|
||||
async with cond:
|
||||
res = await cond.wait(30)
|
||||
results.append(res)
|
||||
|
||||
asyncio.run(atmain(cond, results))
|
||||
|
||||
nthreads = 5
|
||||
threads = []
|
||||
for _ in range(nthreads):
|
||||
threads.append(threading.Thread(target=tmain, args=(cond, results)))
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
await asyncio.sleep(2)
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
self.assertEqual(results, [True] * nthreads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -2389,7 +2389,9 @@ class TestMongoClientFailover(MockClientTest):
|
||||
|
||||
# But it can reconnect.
|
||||
c.revive_host("a:1")
|
||||
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
(c._get_topology()).select_servers(
|
||||
writable_server_selector, _Op.TEST, server_selection_timeout=10
|
||||
)
|
||||
self.assertEqual(c.address, ("a", 1))
|
||||
|
||||
def _test_network_error(self, operation_callback):
|
||||
|
||||
@ -30,6 +30,7 @@ from test.utils import (
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import pymongo
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
|
||||
from pymongo.errors import (
|
||||
ClientBulkWriteException,
|
||||
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(IntegrationTest):
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
client.admin.command("ping") # Init the client first.
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(10):
|
||||
client.admin.command("ping")
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
client.bulk_write(models=models)
|
||||
self.assertIsInstance(context.exception.error, NetworkTimeout)
|
||||
|
||||
@ -1405,7 +1405,7 @@ class TestCursor(IntegrationTest):
|
||||
def test_to_list_csot_applied(self):
|
||||
client = self.single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
@ -1447,7 +1447,7 @@ class TestCursor(IntegrationTest):
|
||||
def test_command_cursor_to_list_csot_applied(self):
|
||||
client = self.single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
|
||||
@ -19,6 +19,7 @@ import os
|
||||
import threading
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
get_pool,
|
||||
@ -27,6 +28,7 @@ from test.utils import (
|
||||
from test.utils_selection_tests import create_topology
|
||||
|
||||
from pymongo.common import clean_node
|
||||
from pymongo.monitoring import ConnectionReadyEvent
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
@ -131,19 +133,20 @@ class TestProse(IntegrationTest):
|
||||
@client_context.require_multiple_mongoses
|
||||
def test_load_balancing(self):
|
||||
listener = OvertCommandListener()
|
||||
cmap_listener = CMAPListener()
|
||||
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
|
||||
# varying RTTs.
|
||||
client = self.rs_client(
|
||||
client_context.mongos_seeds(),
|
||||
appName="loadBalancingTest",
|
||||
event_listeners=[listener],
|
||||
event_listeners=[listener, cmap_listener],
|
||||
localThresholdMS=30000,
|
||||
minPoolSize=10,
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
|
||||
wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections")
|
||||
# Delay find commands on
|
||||
# Wait for both pools to be populated.
|
||||
cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
|
||||
# Delay find commands on only one mongos.
|
||||
delay_finds = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 10000},
|
||||
@ -161,7 +164,7 @@ class TestProse(IntegrationTest):
|
||||
freqs = self.frequencies(client, listener)
|
||||
self.assertLessEqual(freqs[delayed_server], 0.25)
|
||||
listener.reset()
|
||||
freqs = self.frequencies(client, listener, n_finds=100)
|
||||
freqs = self.frequencies(client, listener, n_finds=150)
|
||||
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
|
||||
|
||||
|
||||
|
||||
@ -145,7 +145,17 @@ gridfs_files = [
|
||||
_gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file()
|
||||
]
|
||||
|
||||
test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()]
|
||||
|
||||
def async_only_test(f: str) -> bool:
|
||||
"""Return True for async tests that should not be converted to sync."""
|
||||
return f in ["test_locks.py"]
|
||||
|
||||
|
||||
test_files = [
|
||||
_test_base + f
|
||||
for f in listdir(_test_base)
|
||||
if (Path(_test_base) / f).is_file() and not async_only_test(f)
|
||||
]
|
||||
|
||||
sync_files = [
|
||||
_pymongo_dest_base + f
|
||||
@ -242,7 +252,7 @@ def translate_locks(lines: list[str]) -> list[str]:
|
||||
lock_lines = [line for line in lines if "_Lock(" in line]
|
||||
cond_lines = [line for line in lines if "_Condition(" in line]
|
||||
for line in lock_lines:
|
||||
res = re.search(r"_Lock\(([^()]*\(\))\)", line)
|
||||
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
|
||||
if res:
|
||||
old = res[0]
|
||||
index = lines.index(line)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user