diff --git a/pymongo/network.py b/pymongo/network.py index 01dca0b83..df08158b2 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -244,6 +244,7 @@ def wait_for_read(sock_info, deadline): # Only Monitor connections can be cancelled. if context: sock = sock_info.sock + timed_out = False while True: # SSLSocket can have buffered data which won't be caught by select. if hasattr(sock, "pending") and sock.pending() > 0: @@ -252,7 +253,13 @@ def wait_for_read(sock_info, deadline): # Wait up to 500ms for the socket to become readable and then # check for cancellation. if deadline: - timeout = max(min(deadline - time.monotonic(), _POLL_TIMEOUT), 0.001) + remaining = deadline - time.monotonic() + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + if remaining <= 0: + timed_out = True + timeout = max(min(remaining, _POLL_TIMEOUT), 0) else: timeout = _POLL_TIMEOUT readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout) @@ -260,7 +267,7 @@ def wait_for_read(sock_info, deadline): raise _OperationCancelled("hello cancelled") if readable: return - if deadline and time.monotonic() > deadline: + if timed_out: raise socket.timeout("timed out") diff --git a/test/__init__.py b/test/__init__.py index c432b2609..ee6e3ca50 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -44,6 +44,7 @@ from functools import wraps from test.version import Version from typing import Dict, no_type_check from unittest import SkipTest +from urllib.parse import quote_plus import pymongo import pymongo.errors @@ -279,6 +280,22 @@ class ClientContext(object): opts["replicaSet"] = self.replica_set_name return opts + @property + def uri(self): + """Return the MongoClient URI for creating a duplicate client.""" + opts = client_context.default_client_options.copy() + opts_parts = [] + for opt, val in opts.items(): + strval = str(val) + if isinstance(val, bool): + strval = strval.lower() + opts_parts.append(f"{opt}={quote_plus(strval)}") + opts_part = "&".join(opts_parts) + auth_part = "" + if client_context.auth_enabled: + auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@" + return f"mongodb://{auth_part}{self.pair}/?{opts_part}" + @property def hello(self): if not self._hello: @@ -359,7 +376,7 @@ class ClientContext(object): username=db_user, password=db_pwd, replicaSet=self.replica_set_name, - **self.default_client_options + **self.default_client_options, ) # May not have this if OperationFailure was raised earlier. @@ -387,7 +404,7 @@ class ClientContext(object): username=db_user, password=db_pwd, replicaSet=self.replica_set_name, - **self.default_client_options + **self.default_client_options, ) else: self.client = pymongo.MongoClient( @@ -490,7 +507,7 @@ class ClientContext(object): username=db_user, password=db_pwd, serverSelectionTimeoutMS=100, - **self.default_client_options + **self.default_client_options, ) try: diff --git a/test/sigstop_sigcont.py b/test/sigstop_sigcont.py new file mode 100644 index 000000000..ef4730f0b --- /dev/null +++ b/test/sigstop_sigcont.py @@ -0,0 +1,85 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Used by test_client.TestClient.test_sigstop_sigcont.""" + +import logging +import sys + +sys.path[0:0] = [""] + +from pymongo import monitoring +from pymongo.mongo_client import MongoClient + + +class HeartbeatLogger(monitoring.ServerHeartbeatListener): + """Log events until the listener is closed.""" + + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: + if self.closed: + return + logging.info("%s", event) + + def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: + if self.closed: + return + logging.info("%s", event) + + def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: + if self.closed: + return + logging.warning("%s", event) + + +def main(uri: str) -> None: + heartbeat_logger = HeartbeatLogger() + client = MongoClient( + uri, + event_listeners=[heartbeat_logger], + heartbeatFrequencyMS=500, + connectTimeoutMS=500, + ) + client.admin.command("ping") + logging.info("TEST STARTED") + # test_sigstop_sigcont will SIGSTOP and SIGCONT this process in this loop. + while True: + try: + data = input('Type "q" to quit: ') + except EOFError: + break + if data == "q": + break + client.admin.command("ping") + logging.info("TEST COMPLETED") + heartbeat_logger.close() + client.close() + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("unknown or missing options") + print(f"usage: python3 {sys.argv[0]} 'mongodb://localhost'") + exit(1) + + # Enable logs in this format: + # 2022-03-30 12:40:55,582 INFO + FORMAT = "%(asctime)s %(levelname)s %(message)s" + logging.basicConfig(format=FORMAT, level=logging.INFO) + main(sys.argv[1]) diff --git a/test/test_client.py b/test/test_client.py index 5958ff6d5..40f276a9d 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -23,6 +23,7 @@ import os import signal import socket import struct +import subprocess import sys import threading import time @@ -1688,6 +1689,39 @@ class TestClient(IntegrationTest): ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) + @unittest.skipIf( + client_context.load_balancer or client_context.serverless, + "loadBalanced clients do not run SDAM", + ) + @unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP") + def test_sigstop_sigcont(self): + test_dir = os.path.dirname(os.path.realpath(__file__)) + script = os.path.join(test_dir, "sigstop_sigcont.py") + p = subprocess.Popen( + [sys.executable, script, client_context.uri], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + self.addCleanup(p.wait, timeout=1) + self.addCleanup(p.kill) + time.sleep(1) + # Stop the child, sleep for twice the streaming timeout + # (heartbeatFrequencyMS + connectTimeoutMS), and restart. + os.kill(p.pid, signal.SIGSTOP) + time.sleep(2) + os.kill(p.pid, signal.SIGCONT) + time.sleep(0.5) + # Tell the script to exit gracefully. + outs, _ = p.communicate(input=b"q\n", timeout=10) + self.assertTrue(outs) + log_output = outs.decode("utf-8") + self.assertIn("TEST STARTED", log_output) + self.assertIn("ServerHeartbeatStartedEvent", log_output) + self.assertIn("ServerHeartbeatSucceededEvent", log_output) + self.assertIn("TEST COMPLETED", log_output) + self.assertNotIn("ServerHeartbeatFailedEvent", log_output) + class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors."""