From d0da78ae54b5bab81c18b2d666ca1b6e89be3b0c Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Fri, 1 Sep 2017 16:14:19 -0400 Subject: [PATCH] PYTHON-1332 - Pool server sessions Also check if the topology supports sessions, error otherwise. --- pymongo/client_session.py | 47 ++++++++++++++++-- pymongo/mongo_client.py | 8 +++ pymongo/topology.py | 86 +++++++++++++++++++++++---------- pymongo/topology_description.py | 6 +++ test/__init__.py | 6 +++ test/test_session.py | 26 +++++++++- 6 files changed, 149 insertions(+), 30 deletions(-) diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 7f0eaf14e..8cba8fed7 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -41,6 +41,7 @@ Classes ======= """ +import collections import uuid from bson.binary import Binary @@ -79,7 +80,8 @@ class ClientSession(object): else: self._options = SessionOptions() - self._server_session = _ServerSession() + # Raises ConfigurationError if sessions are not supported. + self._server_session = client._get_server_session() def end_session(self): """Finish this session. @@ -89,7 +91,10 @@ class ClientSession(object): :class:`~pymongo.collection.Collection`, or :class:`~pymongo.cursor.Cursor` after the session has ended. """ - self._has_ended = True + if not self._has_ended: + self._has_ended = True + self.client._return_server_session(self._server_session) + self._server_session = None def __enter__(self): return self @@ -112,7 +117,10 @@ class ClientSession(object): @property def session_id(self): """A BSON document, the opaque server session identifier.""" - return self._server_session.session_id + if self._server_session: + return self._server_session.session_id + + return None @property def has_ended(self): @@ -125,3 +133,36 @@ class _ServerSession(object): # Ensure id is type 4, regardless of CodecOptions.uuid_representation. self.session_id = {'id': Binary(uuid.uuid4().bytes, 4)} self.last_use = monotonic.time() + + def timed_out(self, session_timeout_minutes): + idle_seconds = monotonic.time() - self.last_use + + # Timed out if we have less than a minute to live. + return idle_seconds > (session_timeout_minutes - 1) * 60 + + +class _ServerSessionPool(collections.deque): + """Pool of _ServerSession objects. + + This class is not thread-safe, access it while holding the Topology lock. + """ + def get_server_session(self, session_timeout_minutes): + # The most recently used sessions are on the left. + while self: + s = self.popleft() + if not s.timed_out(session_timeout_minutes): + return s + + return _ServerSession() + + def return_server_session(self, server_session, session_timeout_minutes): + # Clear stale sessions. The least recently used are on the right. + while self: + if self[-1].timed_out(session_timeout_minutes): + self.pop() + else: + # The remaining sessions also haven't timed out. + break + + if not server_session.timed_out(session_timeout_minutes): + self.appendleft(server_session) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index aca788f39..3bbe29c62 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1224,6 +1224,14 @@ class MongoClient(common.BaseObject): opts = client_session.SessionOptions(**kwargs) return client_session.ClientSession(self, opts) + def _get_server_session(self): + """Internal: start or resume a _ServerSession.""" + return self._topology.get_server_session() + + def _return_server_session(self, server_session): + """Internal: return a _ServerSession to the pool.""" + return self._topology.return_server_session(server_session) + def server_info(self): """Get information about the MongoDB server we're connected to.""" return self.admin.command("buildinfo", diff --git a/pymongo/topology.py b/pymongo/topology.py index 67571708c..a4f9bcfc6 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -32,7 +32,7 @@ from pymongo.pool import PoolOptions from pymongo.topology_description import (updated_topology_description, TOPOLOGY_TYPE, TopologyDescription) -from pymongo.errors import ServerSelectionTimeoutError +from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError from pymongo.monotonic import time as _time from pymongo.server import Server from pymongo.server_selectors import (any_server_selector, @@ -40,6 +40,7 @@ from pymongo.server_selectors import (any_server_selector, secondary_server_selector, writable_server_selector, Selection) +from pymongo.client_session import _ServerSessionPool def process_events_queue(queue_ref): @@ -107,6 +108,7 @@ class Topology(object): self._condition = self._settings.condition_class(self._lock) self._servers = {} self._pid = None + self._session_pool = _ServerSessionPool() if self._publish_server or self._publish_tp: def target(): @@ -175,35 +177,41 @@ class Topology(object): server_timeout = server_selection_timeout with self._lock: - now = _time() - end_time = now + server_timeout - server_descriptions = self._description.apply_selector( - selector, address) - - while not server_descriptions: - # No suitable servers. - if server_timeout == 0 or now > end_time: - raise ServerSelectionTimeoutError( - self._error_message(selector)) - - self._ensure_opened() - self._request_check_all() - - # Release the lock and wait for the topology description to - # change, or for a timeout. We won't miss any changes that - # came after our most recent apply_selector call, since we've - # held the lock until now. - self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) - self._description.check_compatible() - now = _time() - server_descriptions = self._description.apply_selector( - selector, address) - - self._description.check_compatible() + server_descriptions = self._select_servers_loop( + selector, server_timeout, address) return [self.get_server_by_address(sd.address) for sd in server_descriptions] + def _select_servers_loop(self, selector, timeout, address): + """select_servers() guts. Hold the lock when calling this.""" + now = _time() + end_time = now + timeout + server_descriptions = self._description.apply_selector( + selector, address) + + while not server_descriptions: + # No suitable servers. + if timeout == 0 or now > end_time: + raise ServerSelectionTimeoutError( + self._error_message(selector)) + + self._ensure_opened() + self._request_check_all() + + # Release the lock and wait for the topology description to + # change, or for a timeout. We won't miss any changes that + # came after our most recent apply_selector call, since we've + # held the lock until now. + self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + self._description.check_compatible() + now = _time() + server_descriptions = self._description.apply_selector( + selector, address) + + self._description.check_compatible() + return server_descriptions + def select_server(self, selector, server_selection_timeout=None, @@ -363,6 +371,32 @@ class Topology(object): def description(self): return self._description + def get_server_session(self): + """Start or resume a server session, or raise ConfigurationError.""" + with self._lock: + session_timeout = self._description.logical_session_timeout_minutes + if session_timeout is None: + # Maybe we need an initial scan? Can raise ServerSelectionError. + if not self.description.has_known_servers: + self._select_servers_loop( + any_server_selector, + self._settings.server_selection_timeout, + None) + + session_timeout = self._description.logical_session_timeout_minutes + if session_timeout is None: + raise ConfigurationError( + "Sessions are not supported by this MongoDB deployment") + + return self._session_pool.get_server_session(session_timeout) + + def return_server_session(self, server_session): + with self._lock: + session_timeout = self._description.logical_session_timeout_minutes + if session_timeout is not None: + self._session_pool.return_server_session(server_session, + session_timeout) + def _new_selection(self): """A Selection object, initially including all known servers. diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index efa3db9df..02b353c09 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -183,6 +183,12 @@ class TopologyDescription(object): return [s for s in self._server_descriptions.values() if s.is_server_type_known] + @property + def has_known_servers(self): + """Whether there are any Servers of types besides Unknown.""" + return any(s for s in self._server_descriptions.values() + if s.is_server_type_known) + @property def common_wire_version(self): """Minimum of all servers' max wire versions, or None.""" diff --git a/test/__init__.py b/test/__init__.py index 8f043e679..5465804f0 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -493,6 +493,12 @@ class IntegrationTest(unittest.TestCase): cls.db = cls.client.pymongo_test +# Use assertRaisesRegex if available, otherwise use Python 2.7's +# deprecated assertRaisesRegexp, with a 'p'. +if not hasattr(unittest.TestCase, 'assertRaisesRegex'): + IntegrationTest.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + + class MockClientTest(unittest.TestCase): """Base class for TestCases that use MockClient. diff --git a/test/test_session.py b/test/test_session.py index f93416499..ddb7f8eaa 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -14,7 +14,7 @@ """Test the client_session module.""" -from pymongo.errors import InvalidOperation +from pymongo.errors import InvalidOperation, ConfigurationError from test import IntegrationTest, client_context from test.utils import ignore_deprecations, rs_or_single_client @@ -32,3 +32,27 @@ class TestSession(IntegrationTest): with self.assertRaises(InvalidOperation): client.start_session() + + @client_context.require_version_min(3, 5, 12) + def test_pool_lifo(self): + # "Pool is LIFO" test from Driver Sessions Spec. + a = self.client.start_session() + b = self.client.start_session() + a_id = a.session_id + b_id = b.session_id + a.end_session() + b.end_session() + + s = self.client.start_session() + self.assertEqual(b_id, s.session_id) + self.assertNotEqual(a_id, s.session_id) + + s = self.client.start_session() + self.assertEqual(a_id, s.session_id) + self.assertNotEqual(b_id, s.session_id) + + @client_context.require_version_max(3, 5, 10) + def test_sessions_not_supported(self): + with self.assertRaisesRegex( + ConfigurationError, "Sessions are not supported"): + self.client.start_session()