From b737b843e974d9524fdbfbc8d18e0004b7743715 Mon Sep 17 00:00:00 2001 From: Julius Park Date: Tue, 1 Mar 2022 15:44:05 -0800 Subject: [PATCH] PYTHON-2956 Drivers should check out an implicit session only after checking out a connection (#876) --- pymongo/client_session.py | 23 ++++++++++++- pymongo/mongo_client.py | 9 ++++- pymongo/topology.py | 16 +++++---- test/test_session.py | 69 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 106 insertions(+), 11 deletions(-) diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 4cf41b2c7..20d36fb06 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -947,9 +947,16 @@ class ClientSession(Generic[_DocumentType]): return self._transaction.opts.read_preference return None + def _materialize(self): + if isinstance(self._server_session, _EmptyServerSession): + old = self._server_session + self._server_session = self._client._topology.get_server_session() + if old.started_retryable_write: + self._server_session.inc_transaction_id() + def _apply_to(self, command, is_retryable, read_preference, sock_info): self._check_ended() - + self._materialize() if self.options.snapshot: self._update_read_concern(command, sock_info) @@ -1000,6 +1007,20 @@ class ClientSession(Generic[_DocumentType]): raise TypeError("A ClientSession cannot be copied, create a new session instead") +class _EmptyServerSession: + __slots__ = "dirty", "started_retryable_write" + + def __init__(self): + self.dirty = False + self.started_retryable_write = False + + def mark_dirty(self): + self.dirty = True + + def inc_transaction_id(self): + self.started_retryable_write = True + + class _ServerSession(object): def __init__(self, generation): # Ensure id is type 4, regardless of CodecOptions.uuid_representation. diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 4965b5e43..4ac4a5ba8 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -66,6 +66,7 @@ from pymongo import ( ) from pymongo.change_stream import ChangeStream, ClusterChangeStream from pymongo.client_options import ClientOptions +from pymongo.client_session import _EmptyServerSession from pymongo.command_cursor import CommandCursor from pymongo.errors import ( AutoReconnect, @@ -1601,7 +1602,11 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): def __start_session(self, implicit, **kwargs): # Raises ConfigurationError if sessions are not supported. - server_session = self._get_server_session() + if implicit: + self._topology._check_implicit_session_support() + server_session = _EmptyServerSession() + else: + server_session = self._get_server_session() opts = client_session.SessionOptions(**kwargs) return client_session.ClientSession(self, server_session, opts, implicit) @@ -1641,6 +1646,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): def _return_server_session(self, server_session, lock): """Internal: return a _ServerSession to the pool.""" + if isinstance(server_session, _EmptyServerSession): + return return self._topology.return_server_session(server_session, lock) def _ensure_session(self, session=None): diff --git a/pymongo/topology.py b/pymongo/topology.py index 6134b8201..03e0d4ee1 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -514,8 +514,15 @@ class Topology(object): with self._lock: return self._session_pool.pop_all() + def _check_implicit_session_support(self): + with self._lock: + self._check_session_support() + def _check_session_support(self): - """Internal check for session support on non-load balanced clusters.""" + """Internal check for session support on clusters.""" + if self._settings.load_balanced: + # Sessions never time out in load balanced mode. + return float("inf") session_timeout = self._description.logical_session_timeout_minutes if session_timeout is None: # Maybe we need an initial scan? Can raise ServerSelectionError. @@ -537,12 +544,7 @@ class Topology(object): def get_server_session(self): """Start or resume a server session, or raise ConfigurationError.""" with self._lock: - # Sessions are always supported in load balanced mode. - if not self._settings.load_balanced: - session_timeout = self._check_session_support() - else: - # Sessions never time out in load balanced mode. - session_timeout = float("inf") + session_timeout = self._check_session_support() return self._session_pool.get_server_session(session_timeout) def return_server_session(self, server_session, lock): diff --git a/test/test_session.py b/test/test_session.py index ec39bb241..53609c70c 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -18,20 +18,28 @@ import copy import sys import time from io import BytesIO -from typing import Set +from typing import Any, Callable, List, Set, Tuple from pymongo.mongo_client import MongoClient sys.path[0:0] = [""] from test import IntegrationTest, SkipTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client, wait_until +from test.utils import ( + EventListener, + ExceptionCatchingThread, + rs_or_single_client, + wait_until, +) from bson import DBRef from gridfs import GridFS, GridFSBucket from pymongo import ASCENDING, IndexModel, InsertOne, monitoring +from pymongo.command_cursor import CommandCursor from pymongo.common import _MAX_END_SESSIONS +from pymongo.cursor import Cursor from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure +from pymongo.operations import UpdateOne from pymongo.read_concern import ReadConcern @@ -171,6 +179,63 @@ class TestSession(IntegrationTest): "%s did not return implicit session to pool" % (f.__name__,), ) + def test_implicit_sessions_checkout(self): + # "To confirm that implicit sessions only allocate their server session after a + # successful connection checkout" test from Driver Sessions Spec. + succeeded = False + failures = 0 + for _ in range(5): + listener = EventListener() + client = rs_or_single_client( + event_listeners=[listener], maxPoolSize=1, retryWrites=True + ) + cursor = client.db.test.find({}) + ops: List[Tuple[Callable, List[Any]]] = [ + (client.db.test.find_one, [{"_id": 1}]), + (client.db.test.delete_one, [{}]), + (client.db.test.update_one, [{}, {"$set": {"x": 2}}]), + (client.db.test.bulk_write, [[UpdateOne({}, {"$set": {"x": 2}})]]), + (client.db.test.find_one_and_delete, [{}]), + (client.db.test.find_one_and_update, [{}, {"$set": {"x": 1}}]), + (client.db.test.find_one_and_replace, [{}, {}]), + (client.db.test.aggregate, [[{"$limit": 1}]]), + (client.db.test.find, []), + (client.server_info, [{}]), + (client.db.aggregate, [[{"$listLocalSessions": {}}, {"$limit": 1}]]), + (cursor.distinct, ["_id"]), + (client.db.list_collections, []), + ] + threads = [] + listener.results.clear() + + def thread_target(op, *args): + res = op(*args) + if isinstance(res, (Cursor, CommandCursor)): + list(res) + + for op, args in ops: + threads.append( + ExceptionCatchingThread( + target=thread_target, args=[op, *args], name=op.__name__ + ) + ) + threads[-1].start() + self.assertEqual(len(threads), len(ops)) + for thread in threads: + thread.join() + self.assertIsNone(thread.exc) + client.close() + lsid_set = set() + for i in listener.results["started"]: + if i.command.get("lsid"): + lsid_set.add(i.command.get("lsid")["id"]) + if len(lsid_set) == 1: + succeeded = True + else: + failures += 1 + print(failures) + self.assertTrue(succeeded) + def test_pool_lifo(self): # "Pool is LIFO" test from Driver Sessions Spec. a = self.client.start_session()