From 4170dc958e2ac1a43d92fe0ea3bb8f22674cff0a Mon Sep 17 00:00:00 2001 From: Ben Warner Date: Tue, 16 Aug 2022 10:40:28 -0700 Subject: [PATCH] PYTHON-3393 Added fork-safety stress test. (#1036) --- test/test_fork.py | 64 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/test/test_fork.py b/test/test_fork.py index b1c98a26f..41ce16249 100644 --- a/test/test_fork.py +++ b/test/test_fork.py @@ -17,6 +17,7 @@ import os from multiprocessing import Pipe from test import IntegrationTest, client_context +from test.utils import ExceptionCatchingThread, rs_or_single_client from unittest import skipIf from bson.objectid import ObjectId @@ -51,7 +52,11 @@ class TestFork(IntegrationTest): # In the parent, it'll return here. # In the child, it'll end with the calling thread. if lock_pid == 0: - os._exit(exit_cond()) + code = -1 + try: + code = exit_cond() + finally: + os._exit(code) else: self.assertEqual(0, os.waitpid(lock_pid, 0)[1]) @@ -65,7 +70,11 @@ class TestFork(IntegrationTest): lock_pid: int = os.fork() if lock_pid == 0: - os._exit(int(ObjectId._inc_lock.locked())) + code = -1 + try: + code = int(ObjectId._inc_lock.locked()) + finally: + os._exit(code) else: self.assertEqual(0, os.waitpid(lock_pid, 0)[1]) @@ -98,3 +107,54 @@ class TestFork(IntegrationTest): self.assertNotEqual(child_id, init_id) passed, msg = parent_conn.recv() self.assertTrue(passed, msg) + + def test_many_threaded(self): + # Fork randomly while doing operations. + + clients = [] + for _ in range(10): + c = rs_or_single_client() + clients.append(c) + self.addCleanup(c.close) + + class ForkThread(ExceptionCatchingThread): + def __init__(self, runner, clients): + self.runner = runner + self.clients = clients + self.fork = False + + super().__init__(target=self.fork_behavior) + + def fork_behavior(self) -> None: + def action(client): + client.admin.command("ping") + return 0 + + for i in range(200): + # Pick a random client. + rc = self.clients[i % len(self.clients)] + if i % 50 == 0 and self.fork: + # Fork + pid = os.fork() + if pid == 0: + code = -1 + try: + for c in self.clients: + action(c) + code = 0 + finally: + os._exit(code) + else: + self.runner.assertEqual(0, os.waitpid(pid, 0)[1]) + action(rc) + + threads = [ForkThread(self, clients) for _ in range(10)] + threads[-1].fork = True + for t in threads: + t.start() + + for t in threads: + t.join() + + for c in clients: + c.close()