diff --git a/pymongo/cluster.py b/pymongo/cluster.py index 0be03a1f9..aee9c3fd3 100644 --- a/pymongo/cluster.py +++ b/pymongo/cluster.py @@ -14,6 +14,7 @@ """Internal classes to monitor clusters of one or more servers.""" +import random import threading import time @@ -83,6 +84,7 @@ class Cluster(object): # ClusterDescription's stringification in exception msg. raise ConnectionFailure("No suitable servers available") + self._ensure_opened() self._request_check_all() # Release the lock and wait for the cluster description to @@ -96,6 +98,10 @@ class Cluster(object): return [self.get_server_by_address(sd.address) for sd in server_descriptions] + def select_server(self, selector, server_wait_time=None): + """Like select_servers, but choose a random server if several match.""" + return random.choice(self.select_servers(selector, server_wait_time)) + def on_change(self, server_description): """Process a new ServerDescription after an ismaster call completes.""" # We do no I/O holding the lock. @@ -140,6 +146,12 @@ class Cluster(object): def description(self): return self._description + def _ensure_opened(self): + """Start monitors. Hold the lock when calling this.""" + if not self._opened: + self._opened = True + self._update_servers() + def _request_check_all(self): """Wake all monitors. Hold the lock when calling this.""" for server in self._servers.values(): diff --git a/test/test_cluster.py b/test/test_cluster.py index 53be9b16b..e29f236ec 100644 --- a/test/test_cluster.py +++ b/test/test_cluster.py @@ -172,7 +172,7 @@ class TestSingleServerCluster(unittest.TestCase): # No matter whether the server is writable, # select_servers() returns it. - s = c.select_servers(writable_server_selector)[0] + s = c.select_server(writable_server_selector) self.assertEqual(server_type, s.description.server_type) def test_reopen(self): @@ -192,7 +192,7 @@ class TestSingleServerCluster(unittest.TestCase): return IsMaster({'ok': 1}), round_trip_time c = create_mock_cluster(monitor_class=TestMonitor) - s = c.select_servers(writable_server_selector)[0] + s = c.select_server(writable_server_selector) self.assertEqual(1, s.description.round_trip_time) round_trip_time = 3 @@ -532,7 +532,7 @@ class TestMultiServerCluster(unittest.TestCase): c = create_mock_cluster(seeds=['a', 'b'], set_name='rs') def write_batch_size(): - s = c.select_servers(writable_server_selector)[0] + s = c.select_server(writable_server_selector) return s.description.max_write_batch_size got_ismaster(c, ('a', 27017), { @@ -581,7 +581,7 @@ class TestClusterErrors(unittest.TestCase): c = create_mock_cluster(monitor_class=TestMonitor) # Await first ismaster call. - s = c.select_servers(writable_server_selector)[0] + s = c.select_server(writable_server_selector) self.assertEqual(1, ismaster_count[0]) pool_id = s.pool.pool_id @@ -604,7 +604,7 @@ class TestClusterErrors(unittest.TestCase): c = create_mock_cluster(monitor_class=TestMonitor) # Await first ismaster call. - s = c.select_servers(writable_server_selector)[0] + s = c.select_server(writable_server_selector) self.assertEqual(1, ismaster_count[0]) self.assertEqual(SERVER_TYPE.Standalone, s.description.server_type)