diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 7053f20e1..f7da8c278 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -16,9 +16,12 @@ from __future__ import annotations import os +import socketserver import sys import threading +from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent + sys.path[0:0] = [""] from test import IntegrationTest, unittest @@ -27,6 +30,7 @@ from test.unified_format import generate_test_classes from test.utils import ( CMAPListener, HeartbeatEventListener, + HeartbeatEventsListListener, assertion_context, client_context, get_pool, @@ -38,7 +42,7 @@ from test.utils import ( from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import common, monitoring +from pymongo import MongoClient, common, monitoring from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -396,6 +400,44 @@ class TestServerMonitoringMode(IntegrationTest): self.assertIsNone(monitor._rtt_monitor._executor._thread) +class MockTCPHandler(socketserver.BaseRequestHandler): + def handle(self): + self.server.events.append("client connected") + if self.request.recv(1024).strip(): + self.server.events.append("client hello received") + self.request.close() + + +class TestHeartbeatStartOrdering(unittest.TestCase): + def start_server(self, events): + server = socketserver.TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server.handle_request() + server.server_close() + + def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + server_thread = threading.Thread(target=self.start_server, args=(events,)) + server_thread.start() + _c = MongoClient( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + + # Generate unified tests. globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__)) diff --git a/test/utils.py b/test/utils.py index fb0e52292..b514b7a08 100644 --- a/test/utils.py +++ b/test/utils.py @@ -281,6 +281,26 @@ class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): self.add_event(event) +class HeartbeatEventsListListener(HeartbeatEventListener): + """Listens to only server heartbeat events and publishes them to a provided list.""" + + def __init__(self, events): + super().__init__() + self.event_list = events + + def started(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatStartedEvent") + + def succeeded(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatSucceededEvent") + + def failed(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatFailedEvent") + + class MockConnection: def __init__(self): self.cancel_context = _CancellationContext()