diff --git a/test/__init__.py b/test/__init__.py index 2a3e59adf..a3e1ca734 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -43,7 +43,7 @@ except ImportError: from contextlib import contextmanager from functools import wraps from test.version import Version -from typing import Dict, no_type_check +from typing import Dict, Generator, no_type_check from unittest import SkipTest from urllib.parse import quote_plus @@ -998,6 +998,33 @@ class PyMongoTestCase(unittest.TestCase): "configureFailPoint", cmd_on["configureFailPoint"], mode="off" ) + @contextmanager + def fork(self) -> Generator[int, None, None]: + """Helper for tests that use os.fork() + + Use in a with statement: + + with self.fork() as pid: + if pid == 0: # Child + pass + else: # Parent + pass + """ + pid = os.fork() + in_child = pid == 0 + try: + yield pid + except: + if in_child: + traceback.print_exc() + os._exit(1) + raise + finally: + if in_child: + os._exit(0) + # In parent, assert child succeeded. + self.assertEqual(0, os.waitpid(pid, 0)[1]) + class IntegrationTest(PyMongoTestCase): """Base class for TestCases that need a connection to MongoDB to pass.""" diff --git a/test/test_encryption.py b/test/test_encryption.py index 4ed415d4d..414669570 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -330,7 +330,6 @@ class TestClientSimple(EncryptionIntegrationTest): with self.assertRaisesRegex(InvalidOperation, "Cannot use MongoClient after close"): client.admin.command("ping") - # Not available for versions of Python without "register_at_fork" @unittest.skipIf( not hasattr(os, "register_at_fork"), "register_at_fork not available in this version of Python", @@ -342,14 +341,7 @@ class TestClientSimple(EncryptionIntegrationTest): def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = rs_or_single_client(auto_encryption_opts=opts) - - lock_pid = os.fork() - if lock_pid == 0: - client.admin.command("ping") - client.close() - os._exit(0) - else: - self.assertEqual(0, os.waitpid(lock_pid, 0)[1]) + with self.fork(): client.admin.command("ping") client.close() diff --git a/test/test_fork.py b/test/test_fork.py index df1f009e2..092ac434a 100644 --- a/test/test_fork.py +++ b/test/test_fork.py @@ -16,7 +16,7 @@ import os from multiprocessing import Pipe -from test import IntegrationTest, client_context +from test import IntegrationTest from test.utils import ( ExceptionCatchingThread, is_greenthread_patched, @@ -27,12 +27,6 @@ from unittest import skipIf from bson.objectid import ObjectId -@client_context.require_connection -def setUpModule(): - pass - - -# Not available for versions of Python without "register_at_fork" @skipIf( not hasattr(os, "register_at_fork"), "register_at_fork not available in this version of Python" ) @@ -42,83 +36,52 @@ def setUpModule(): ) class TestFork(IntegrationTest): def test_lock_client(self): - """ - Forks the client with some items locked. - Parent => All locks should be as before the fork. - Child => All locks should be reset. - """ - - def exit_cond(): - self.client.admin.command("ping") - return 0 - + # Forks the client with some items locked. + # Parent => All locks should be as before the fork. + # Child => All locks should be reset. with self.client._MongoClient__lock: - # Call _get_topology, will launch a thread to fork upon __enter__ing - # the with region. - lock_pid = os.fork() - # The POSIX standard states only the forking thread is cloned. - # In the parent, it'll return here. - # In the child, it'll end with the calling thread. - if lock_pid == 0: - code = -1 - try: - code = exit_cond() - finally: - os._exit(code) - else: - self.assertEqual(0, os.waitpid(lock_pid, 0)[1]) + with self.fork() as pid: + if pid == 0: # Child + self.client.admin.command("ping") + self.client.admin.command("ping") def test_lock_object_id(self): - """ - Forks the client with ObjectId's _inc_lock locked. - Parent => _inc_lock should remain locked. - Child => _inc_lock should be unlocked. - """ + # Forks the client with ObjectId's _inc_lock locked. + # Parent => _inc_lock should remain locked. + # Child => _inc_lock should be unlocked. with ObjectId._inc_lock: - lock_pid: int = os.fork() - - if lock_pid == 0: - code = -1 - try: - code = int(ObjectId._inc_lock.locked()) - finally: - os._exit(code) - else: - self.assertEqual(0, os.waitpid(lock_pid, 0)[1]) + with self.fork() as pid: + if pid == 0: # Child + self.assertFalse(ObjectId._inc_lock.locked()) + self.assertTrue(ObjectId()) def test_topology_reset(self): - """ - Tests that topologies are different from each other. - Cannot use ID because virtual memory addresses may be the same. - Cannot reinstantiate ObjectId in the topology settings. - Relies on difference in PID when opened again. - """ + # Tests that topologies are different from each other. + # Cannot use ID because virtual memory addresses may be the same. + # Cannot reinstantiate ObjectId in the topology settings. + # Relies on difference in PID when opened again. parent_conn, child_conn = Pipe() init_id = self.client._topology._pid parent_cursor_exc = self.client._kill_cursors_executor - lock_pid: int = os.fork() - - if lock_pid == 0: # Child - self.client.admin.command("ping") - child_conn.send(self.client._topology._pid) - child_conn.send( - ( - parent_cursor_exc != self.client._kill_cursors_executor, - "client._kill_cursors_executor was not reinitialized", + with self.fork() as pid: + if pid == 0: # Child + self.client.admin.command("ping") + child_conn.send(self.client._topology._pid) + child_conn.send( + ( + parent_cursor_exc != self.client._kill_cursors_executor, + "client._kill_cursors_executor was not reinitialized", + ) ) - ) - os._exit(0) - else: # Parent - self.assertEqual(0, os.waitpid(lock_pid, 0)[1]) - self.assertEqual(self.client._topology._pid, init_id) - child_id = parent_conn.recv() - self.assertNotEqual(child_id, init_id) - passed, msg = parent_conn.recv() - self.assertTrue(passed, msg) + else: # Parent + self.assertEqual(self.client._topology._pid, init_id) + child_id = parent_conn.recv() + self.assertNotEqual(child_id, init_id) + passed, msg = parent_conn.recv() + self.assertTrue(passed, msg) def test_many_threaded(self): # Fork randomly while doing operations. - clients = [] for _ in range(10): c = rs_or_single_client() @@ -143,17 +106,10 @@ class TestFork(IntegrationTest): rc = self.clients[i % len(self.clients)] if i % 50 == 0 and self.fork: # Fork - pid = os.fork() - if pid == 0: - code = -1 - try: + with self.runner.fork() as pid: + if pid == 0: # Child for c in self.clients: action(c) - code = 0 - finally: - os._exit(code) - else: - self.runner.assertEqual(0, os.waitpid(pid, 0)[1]) action(rc) threads = [ForkThread(self, clients) for _ in range(10)]