PYTHON-1332 - Pool server sessions

Also check if the topology supports sessions, error otherwise.
This commit is contained in:
A. Jesse Jiryu Davis 2017-09-01 16:14:19 -04:00
parent 3c49e8a0f7
commit d0da78ae54
6 changed files with 149 additions and 30 deletions

View File

@ -41,6 +41,7 @@ Classes
=======
"""
import collections
import uuid
from bson.binary import Binary
@ -79,7 +80,8 @@ class ClientSession(object):
else:
self._options = SessionOptions()
self._server_session = _ServerSession()
# Raises ConfigurationError if sessions are not supported.
self._server_session = client._get_server_session()
def end_session(self):
"""Finish this session.
@ -89,7 +91,10 @@ class ClientSession(object):
:class:`~pymongo.collection.Collection`, or
:class:`~pymongo.cursor.Cursor` after the session has ended.
"""
self._has_ended = True
if not self._has_ended:
self._has_ended = True
self.client._return_server_session(self._server_session)
self._server_session = None
def __enter__(self):
return self
@ -112,7 +117,10 @@ class ClientSession(object):
@property
def session_id(self):
"""A BSON document, the opaque server session identifier."""
return self._server_session.session_id
if self._server_session:
return self._server_session.session_id
return None
@property
def has_ended(self):
@ -125,3 +133,36 @@ class _ServerSession(object):
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.
self.session_id = {'id': Binary(uuid.uuid4().bytes, 4)}
self.last_use = monotonic.time()
def timed_out(self, session_timeout_minutes):
idle_seconds = monotonic.time() - self.last_use
# Timed out if we have less than a minute to live.
return idle_seconds > (session_timeout_minutes - 1) * 60
class _ServerSessionPool(collections.deque):
"""Pool of _ServerSession objects.
This class is not thread-safe, access it while holding the Topology lock.
"""
def get_server_session(self, session_timeout_minutes):
# The most recently used sessions are on the left.
while self:
s = self.popleft()
if not s.timed_out(session_timeout_minutes):
return s
return _ServerSession()
def return_server_session(self, server_session, session_timeout_minutes):
# Clear stale sessions. The least recently used are on the right.
while self:
if self[-1].timed_out(session_timeout_minutes):
self.pop()
else:
# The remaining sessions also haven't timed out.
break
if not server_session.timed_out(session_timeout_minutes):
self.appendleft(server_session)

View File

@ -1224,6 +1224,14 @@ class MongoClient(common.BaseObject):
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, opts)
def _get_server_session(self):
"""Internal: start or resume a _ServerSession."""
return self._topology.get_server_session()
def _return_server_session(self, server_session):
"""Internal: return a _ServerSession to the pool."""
return self._topology.return_server_session(server_session)
def server_info(self):
"""Get information about the MongoDB server we're connected to."""
return self.admin.command("buildinfo",

View File

@ -32,7 +32,7 @@ from pymongo.pool import PoolOptions
from pymongo.topology_description import (updated_topology_description,
TOPOLOGY_TYPE,
TopologyDescription)
from pymongo.errors import ServerSelectionTimeoutError
from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError
from pymongo.monotonic import time as _time
from pymongo.server import Server
from pymongo.server_selectors import (any_server_selector,
@ -40,6 +40,7 @@ from pymongo.server_selectors import (any_server_selector,
secondary_server_selector,
writable_server_selector,
Selection)
from pymongo.client_session import _ServerSessionPool
def process_events_queue(queue_ref):
@ -107,6 +108,7 @@ class Topology(object):
self._condition = self._settings.condition_class(self._lock)
self._servers = {}
self._pid = None
self._session_pool = _ServerSessionPool()
if self._publish_server or self._publish_tp:
def target():
@ -175,35 +177,41 @@ class Topology(object):
server_timeout = server_selection_timeout
with self._lock:
now = _time()
end_time = now + server_timeout
server_descriptions = self._description.apply_selector(
selector, address)
while not server_descriptions:
# No suitable servers.
if server_timeout == 0 or now > end_time:
raise ServerSelectionTimeoutError(
self._error_message(selector))
self._ensure_opened()
self._request_check_all()
# Release the lock and wait for the topology description to
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = _time()
server_descriptions = self._description.apply_selector(
selector, address)
self._description.check_compatible()
server_descriptions = self._select_servers_loop(
selector, server_timeout, address)
return [self.get_server_by_address(sd.address)
for sd in server_descriptions]
def _select_servers_loop(self, selector, timeout, address):
"""select_servers() guts. Hold the lock when calling this."""
now = _time()
end_time = now + timeout
server_descriptions = self._description.apply_selector(
selector, address)
while not server_descriptions:
# No suitable servers.
if timeout == 0 or now > end_time:
raise ServerSelectionTimeoutError(
self._error_message(selector))
self._ensure_opened()
self._request_check_all()
# Release the lock and wait for the topology description to
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = _time()
server_descriptions = self._description.apply_selector(
selector, address)
self._description.check_compatible()
return server_descriptions
def select_server(self,
selector,
server_selection_timeout=None,
@ -363,6 +371,32 @@ class Topology(object):
def description(self):
return self._description
def get_server_session(self):
"""Start or resume a server session, or raise ConfigurationError."""
with self._lock:
session_timeout = self._description.logical_session_timeout_minutes
if session_timeout is None:
# Maybe we need an initial scan? Can raise ServerSelectionError.
if not self.description.has_known_servers:
self._select_servers_loop(
any_server_selector,
self._settings.server_selection_timeout,
None)
session_timeout = self._description.logical_session_timeout_minutes
if session_timeout is None:
raise ConfigurationError(
"Sessions are not supported by this MongoDB deployment")
return self._session_pool.get_server_session(session_timeout)
def return_server_session(self, server_session):
with self._lock:
session_timeout = self._description.logical_session_timeout_minutes
if session_timeout is not None:
self._session_pool.return_server_session(server_session,
session_timeout)
def _new_selection(self):
"""A Selection object, initially including all known servers.

View File

@ -183,6 +183,12 @@ class TopologyDescription(object):
return [s for s in self._server_descriptions.values()
if s.is_server_type_known]
@property
def has_known_servers(self):
"""Whether there are any Servers of types besides Unknown."""
return any(s for s in self._server_descriptions.values()
if s.is_server_type_known)
@property
def common_wire_version(self):
"""Minimum of all servers' max wire versions, or None."""

View File

@ -493,6 +493,12 @@ class IntegrationTest(unittest.TestCase):
cls.db = cls.client.pymongo_test
# Use assertRaisesRegex if available, otherwise use Python 2.7's
# deprecated assertRaisesRegexp, with a 'p'.
if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
IntegrationTest.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
class MockClientTest(unittest.TestCase):
"""Base class for TestCases that use MockClient.

View File

@ -14,7 +14,7 @@
"""Test the client_session module."""
from pymongo.errors import InvalidOperation
from pymongo.errors import InvalidOperation, ConfigurationError
from test import IntegrationTest, client_context
from test.utils import ignore_deprecations, rs_or_single_client
@ -32,3 +32,27 @@ class TestSession(IntegrationTest):
with self.assertRaises(InvalidOperation):
client.start_session()
@client_context.require_version_min(3, 5, 12)
def test_pool_lifo(self):
# "Pool is LIFO" test from Driver Sessions Spec.
a = self.client.start_session()
b = self.client.start_session()
a_id = a.session_id
b_id = b.session_id
a.end_session()
b.end_session()
s = self.client.start_session()
self.assertEqual(b_id, s.session_id)
self.assertNotEqual(a_id, s.session_id)
s = self.client.start_session()
self.assertEqual(a_id, s.session_id)
self.assertNotEqual(b_id, s.session_id)
@client_context.require_version_max(3, 5, 10)
def test_sessions_not_supported(self):
with self.assertRaisesRegex(
ConfigurationError, "Sessions are not supported"):
self.client.start_session()