PYTHON-3406 Refactor fork tests to print traceback on failure (#1042)

This commit is contained in:
Shane Harvey 2022-08-18 17:06:02 -07:00 committed by GitHub
parent a0a5c7194d
commit 7f19186cac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 90 deletions

View File

@ -43,7 +43,7 @@ except ImportError:
from contextlib import contextmanager
from functools import wraps
from test.version import Version
from typing import Dict, no_type_check
from typing import Dict, Generator, no_type_check
from unittest import SkipTest
from urllib.parse import quote_plus
@ -998,6 +998,33 @@ class PyMongoTestCase(unittest.TestCase):
"configureFailPoint", cmd_on["configureFailPoint"], mode="off"
)
@contextmanager
def fork(self) -> Generator[int, None, None]:
"""Helper for tests that use os.fork()
Use in a with statement:
with self.fork() as pid:
if pid == 0: # Child
pass
else: # Parent
pass
"""
pid = os.fork()
in_child = pid == 0
try:
yield pid
except:
if in_child:
traceback.print_exc()
os._exit(1)
raise
finally:
if in_child:
os._exit(0)
# In parent, assert child succeeded.
self.assertEqual(0, os.waitpid(pid, 0)[1])
class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""

View File

@ -330,7 +330,6 @@ class TestClientSimple(EncryptionIntegrationTest):
with self.assertRaisesRegex(InvalidOperation, "Cannot use MongoClient after close"):
client.admin.command("ping")
# Not available for versions of Python without "register_at_fork"
@unittest.skipIf(
not hasattr(os, "register_at_fork"),
"register_at_fork not available in this version of Python",
@ -342,14 +341,7 @@ class TestClientSimple(EncryptionIntegrationTest):
def test_fork(self):
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
client = rs_or_single_client(auto_encryption_opts=opts)
lock_pid = os.fork()
if lock_pid == 0:
client.admin.command("ping")
client.close()
os._exit(0)
else:
self.assertEqual(0, os.waitpid(lock_pid, 0)[1])
with self.fork():
client.admin.command("ping")
client.close()

View File

@ -16,7 +16,7 @@
import os
from multiprocessing import Pipe
from test import IntegrationTest, client_context
from test import IntegrationTest
from test.utils import (
ExceptionCatchingThread,
is_greenthread_patched,
@ -27,12 +27,6 @@ from unittest import skipIf
from bson.objectid import ObjectId
@client_context.require_connection
def setUpModule():
pass
# Not available for versions of Python without "register_at_fork"
@skipIf(
not hasattr(os, "register_at_fork"), "register_at_fork not available in this version of Python"
)
@ -42,83 +36,52 @@ def setUpModule():
)
class TestFork(IntegrationTest):
def test_lock_client(self):
"""
Forks the client with some items locked.
Parent => All locks should be as before the fork.
Child => All locks should be reset.
"""
def exit_cond():
self.client.admin.command("ping")
return 0
# Forks the client with some items locked.
# Parent => All locks should be as before the fork.
# Child => All locks should be reset.
with self.client._MongoClient__lock:
# Call _get_topology, will launch a thread to fork upon __enter__ing
# the with region.
lock_pid = os.fork()
# The POSIX standard states only the forking thread is cloned.
# In the parent, it'll return here.
# In the child, it'll end with the calling thread.
if lock_pid == 0:
code = -1
try:
code = exit_cond()
finally:
os._exit(code)
else:
self.assertEqual(0, os.waitpid(lock_pid, 0)[1])
with self.fork() as pid:
if pid == 0: # Child
self.client.admin.command("ping")
self.client.admin.command("ping")
def test_lock_object_id(self):
"""
Forks the client with ObjectId's _inc_lock locked.
Parent => _inc_lock should remain locked.
Child => _inc_lock should be unlocked.
"""
# Forks the client with ObjectId's _inc_lock locked.
# Parent => _inc_lock should remain locked.
# Child => _inc_lock should be unlocked.
with ObjectId._inc_lock:
lock_pid: int = os.fork()
if lock_pid == 0:
code = -1
try:
code = int(ObjectId._inc_lock.locked())
finally:
os._exit(code)
else:
self.assertEqual(0, os.waitpid(lock_pid, 0)[1])
with self.fork() as pid:
if pid == 0: # Child
self.assertFalse(ObjectId._inc_lock.locked())
self.assertTrue(ObjectId())
def test_topology_reset(self):
"""
Tests that topologies are different from each other.
Cannot use ID because virtual memory addresses may be the same.
Cannot reinstantiate ObjectId in the topology settings.
Relies on difference in PID when opened again.
"""
# Tests that topologies are different from each other.
# Cannot use ID because virtual memory addresses may be the same.
# Cannot reinstantiate ObjectId in the topology settings.
# Relies on difference in PID when opened again.
parent_conn, child_conn = Pipe()
init_id = self.client._topology._pid
parent_cursor_exc = self.client._kill_cursors_executor
lock_pid: int = os.fork()
if lock_pid == 0: # Child
self.client.admin.command("ping")
child_conn.send(self.client._topology._pid)
child_conn.send(
(
parent_cursor_exc != self.client._kill_cursors_executor,
"client._kill_cursors_executor was not reinitialized",
with self.fork() as pid:
if pid == 0: # Child
self.client.admin.command("ping")
child_conn.send(self.client._topology._pid)
child_conn.send(
(
parent_cursor_exc != self.client._kill_cursors_executor,
"client._kill_cursors_executor was not reinitialized",
)
)
)
os._exit(0)
else: # Parent
self.assertEqual(0, os.waitpid(lock_pid, 0)[1])
self.assertEqual(self.client._topology._pid, init_id)
child_id = parent_conn.recv()
self.assertNotEqual(child_id, init_id)
passed, msg = parent_conn.recv()
self.assertTrue(passed, msg)
else: # Parent
self.assertEqual(self.client._topology._pid, init_id)
child_id = parent_conn.recv()
self.assertNotEqual(child_id, init_id)
passed, msg = parent_conn.recv()
self.assertTrue(passed, msg)
def test_many_threaded(self):
# Fork randomly while doing operations.
clients = []
for _ in range(10):
c = rs_or_single_client()
@ -143,17 +106,10 @@ class TestFork(IntegrationTest):
rc = self.clients[i % len(self.clients)]
if i % 50 == 0 and self.fork:
# Fork
pid = os.fork()
if pid == 0:
code = -1
try:
with self.runner.fork() as pid:
if pid == 0: # Child
for c in self.clients:
action(c)
code = 0
finally:
os._exit(code)
else:
self.runner.assertEqual(0, os.waitpid(pid, 0)[1])
action(rc)
threads = [ForkThread(self, clients) for _ in range(10)]